Skip to content

Commit 0c47594

Browse files
natkeJRAlexander
authored andcommitted
Remove unnecessary training data from Fit() methods (#1129)
1 parent 33a9867 commit 0c47594

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
Month,ProductSales
2+
1-Jan,271
3+
2-Jan,150.9
4+
3-Jan,188.1
5+
4-Jan,124.3
6+
5-Jan,185.3
7+
6-Jan,173.5
8+
7-Jan,236.8
9+
8-Jan,229.5
10+
9-Jan,197.8
11+
10-Jan,127.9
12+
11-Jan,341.5
13+
12-Jan,190.9
14+
1-Feb,199.3
15+
2-Feb,154.5
16+
3-Feb,215.1
17+
4-Feb,278.3
18+
5-Feb,196.4
19+
6-Feb,292
20+
7-Feb,231
21+
8-Feb,308.6
22+
9-Feb,294.9
23+
10-Feb,426.6
24+
11-Feb,269.5
25+
12-Feb,347.3
26+
1-Mar,344.7
27+
2-Mar,445.4
28+
3-Mar,320.9
29+
4-Mar,444.3
30+
5-Mar,406.3
31+
6-Mar,442.4
32+
7-Mar,580.5
33+
8-Mar,412.6
34+
9-Mar,687
35+
10-Mar,480.3
36+
11-Mar,586.3
37+
12-Mar,651.9

machine-learning/tutorials/ProductSalesAnomalyDetection/Program.cs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.IO;
44
using Microsoft.ML;
5+
using System.Collections.Generic;
56
// </SnippetAddUsings>
67

78
namespace ProductSalesAnomalyDetection
@@ -44,17 +45,17 @@ static void DetectSpike(MLContext mlContext, int docSize, IDataView productSales
4445
var iidSpikeEstimator = mlContext.Transforms.DetectIidSpike(outputColumnName: nameof(ProductSalesPrediction.Prediction), inputColumnName: nameof(ProductSalesData.numSales), confidence: 95, pvalueHistoryLength: docSize / 4);
4546
// </SnippetAddSpikeTrainer>
4647

47-
// STEP 3:Train the model by fitting the dataview
48-
// Create and train the model based on the dataset that has been loaded, transformed.
48+
// STEP 3: Create the transform
49+
// Create the spike detection transform
4950
Console.WriteLine("=============== Training the model ===============");
5051
// <SnippetTrainModel1>
51-
ITransformer trainedModel = iidSpikeEstimator.Fit(productSales);
52+
ITransformer iidSpikeTransform = iidSpikeEstimator.Fit(CreateEmptyDataView(mlContext));
5253
// </SnippetTrainModel1>
5354

5455
Console.WriteLine("=============== End of training process ===============");
5556
//Apply data transformation to create predictions.
5657
// <SnippetTransformData1>
57-
IDataView transformedData = trainedModel.Transform(productSales);
58+
IDataView transformedData = iidSpikeTransform.Transform(productSales);
5859
// </SnippetTransformData1>
5960

6061
// <SnippetCreateEnumerable1>
@@ -90,16 +91,16 @@ static void DetectChangepoint(MLContext mlContext, int docSize, IDataView produc
9091
var iidChangePointEstimator = mlContext.Transforms.DetectIidChangePoint(outputColumnName: nameof(ProductSalesPrediction.Prediction), inputColumnName: nameof(ProductSalesData.numSales), confidence: 95, changeHistoryLength: docSize / 4);
9192
// </SnippetAddChangePointTrainer>
9293

93-
//STEP 3:Train the model by fitting the dataview
94+
//STEP 3: Create the transform
9495
Console.WriteLine("=============== Training the model Using Change Point Detection Algorithm===============");
9596
// <SnippetTrainModel2>
96-
var trainedModel = iidChangePointEstimator.Fit(productSales);
97+
var iidChangePointTransform = iidChangePointEstimator.Fit(CreateEmptyDataView(mlContext));
9798
// </SnippetTrainModel2>
9899
Console.WriteLine("=============== End of training process ===============");
99100

100101
//Apply data transformation to create predictions.
101102
// <SnippetTransformData2>
102-
IDataView transformedData = trainedModel.Transform(productSales);
103+
IDataView transformedData = iidChangePointTransform.Transform(productSales);
103104
// </SnippetTransformData2>
104105

105106
// <SnippetCreateEnumerable2>
@@ -124,5 +125,13 @@ static void DetectChangepoint(MLContext mlContext, int docSize, IDataView produc
124125
Console.WriteLine("");
125126
// </SnippetDisplayResults2>
126127
}
128+
129+
// <SnippetCreateEmptyDataView>
130+
static IDataView CreateEmptyDataView(MLContext mlContext) {
131+
// Create empty DataView. We just need the schema to call Fit() for the time series transforms
132+
IEnumerable<ProductSalesData> enumerableData = new List<ProductSalesData>();
133+
return mlContext.Data.LoadFromEnumerable(enumerableData);
134+
}
135+
// </SnippetCreateEmptyDataView>
127136
}
128137
}

0 commit comments

Comments
 (0)