Skip to content

Commit 2153585

Browse files
Merge pull request #43 from areddish/areddish/update-sample
Custom Vision: Update samples to work with latest sdk
2 parents db78074 + 7ddd7a9 commit 2153585

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

Vision/CustomVision/pom.xml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@
6161
<version>3.3</version>
6262
</dependency>
6363
<dependency>
64-
<groupId>com.microsoft.azure.cognitiveservices</groupId>
64+
<groupId>com.azure</groupId>
6565
<artifactId>azure-cognitiveservices-customvision-training</artifactId>
66-
<version>1.0.2-beta</version>
66+
<version>1.1.0-preview.2</version>
6767
</dependency>
6868
<dependency>
69-
<groupId>com.microsoft.azure.cognitiveservices</groupId>
69+
<groupId>com.azure</groupId>
7070
<artifactId>azure-cognitiveservices-customvision-prediction</artifactId>
71-
<version>1.0.2-beta</version>
71+
<version>1.1.0-preview.2</version>
7272
</dependency>
7373
</dependencies>
7474
</project>

Vision/CustomVision/src/main/java/com/microsoft/azure/cognitiveservices/vision/customvision/samples/CustomVisionSamples.java

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
import com.microsoft.azure.cognitiveservices.vision.customvision.training.models.Iteration;
2222
import com.microsoft.azure.cognitiveservices.vision.customvision.training.models.Project;
2323
import com.microsoft.azure.cognitiveservices.vision.customvision.training.models.Region;
24-
import com.microsoft.azure.cognitiveservices.vision.customvision.training.TrainingApi;
24+
import com.microsoft.azure.cognitiveservices.vision.customvision.training.models.TrainProjectOptionalParameter;
25+
import com.microsoft.azure.cognitiveservices.vision.customvision.training.CustomVisionTrainingClient;
2526
import com.microsoft.azure.cognitiveservices.vision.customvision.training.Trainings;
2627
import com.microsoft.azure.cognitiveservices.vision.customvision.training.CustomVisionTrainingManager;
2728
import com.microsoft.azure.cognitiveservices.vision.customvision.prediction.models.ImagePrediction;
2829
import com.microsoft.azure.cognitiveservices.vision.customvision.prediction.models.Prediction;
29-
import com.microsoft.azure.cognitiveservices.vision.customvision.prediction.PredictionEndpoint;
30+
import com.microsoft.azure.cognitiveservices.vision.customvision.prediction.CustomVisionPredictionClient;
3031
import com.microsoft.azure.cognitiveservices.vision.customvision.prediction.CustomVisionPredictionManager;
3132
import com.microsoft.azure.cognitiveservices.vision.customvision.training.models.Tag;
3233

@@ -37,7 +38,7 @@ public class CustomVisionSamples {
3738
* @param trainer the Custom Vision Training client object
3839
* @param predictor the Custom Vision Prediction client object
3940
*/
40-
public static void runSample(TrainingApi trainer, PredictionEndpoint predictor) {
41+
public static void runSample(CustomVisionTrainingClient trainer, CustomVisionPredictionClient predictor) {
4142
try {
4243
// This demonstrates how to create an image classification project, upload images,
4344
// train it and make a prediction.
@@ -52,7 +53,7 @@ public static void runSample(TrainingApi trainer, PredictionEndpoint predictor)
5253
}
5354
}
5455

55-
public static void ImageClassification_Sample(TrainingApi trainClient, PredictionEndpoint predictor) {
56+
public static void ImageClassification_Sample(CustomVisionTrainingClient trainClient, CustomVisionPredictionClient predictor) {
5657
try {
5758
// <snippet_create>
5859
System.out.println("ImageClassification Sample");
@@ -94,7 +95,7 @@ public static void ImageClassification_Sample(TrainingApi trainClient, Predictio
9495

9596
// <snippet_train>
9697
System.out.println("Training...");
97-
Iteration iteration = trainer.trainProject(project.id());
98+
Iteration iteration = trainer.trainProject(project.id(), new TrainProjectOptionalParameter());
9899

99100
while (iteration.status().equals("Training"))
100101
{
@@ -103,13 +104,18 @@ public static void ImageClassification_Sample(TrainingApi trainClient, Predictio
103104
iteration = trainer.getIteration(project.id(), iteration.id());
104105
}
105106
System.out.println("Training Status: "+ iteration.status());
106-
trainer.updateIteration(project.id(), iteration.id(), iteration.withIsDefault(true));
107+
108+
// The iteration is now trained. Publish it to the prediction endpoint.
109+
String publishedModelName = "myModel";
110+
String predictionResourceId = System.getenv("AZURE_CUSTOMVISION_PREDICTION_ID");
111+
trainer.publishIteration(project.id(), iteration.id(), publishedModelName, predictionResourceId);
107112
// </snippet_train>
108113

109114
// use below for url
110115
// String url = "some url";
111-
// ImagePrediction results = predictor.predictions().predictImage()
116+
// ImagePrediction results = predictor.predictions().classifyImageUrl()
112117
// .withProjectId(project.id())
118+
// .withPublishedName(publishedModelName)
113119
// .withUrl(url)
114120
// .execute();
115121

@@ -118,8 +124,9 @@ public static void ImageClassification_Sample(TrainingApi trainClient, Predictio
118124
byte[] testImage = GetImage("/Test", "test_image.jpg");
119125

120126
// predict
121-
ImagePrediction results = predictor.predictions().predictImage()
127+
ImagePrediction results = predictor.predictions().classifyImage()
122128
.withProjectId(project.id())
129+
.withPublishedName(publishedModelName)
123130
.withImageData(testImage)
124131
.execute();
125132

@@ -134,7 +141,7 @@ public static void ImageClassification_Sample(TrainingApi trainClient, Predictio
134141
}
135142
}
136143

137-
public static void ObjectDetection_Sample(TrainingApi trainClient, PredictionEndpoint predictor)
144+
public static void ObjectDetection_Sample(CustomVisionTrainingClient trainClient, CustomVisionPredictionClient predictor)
138145
{
139146
try {
140147
// <snippet_od_mapping>
@@ -250,21 +257,27 @@ public static void ObjectDetection_Sample(TrainingApi trainClient, PredictionEnd
250257

251258
// <snippet_train_od>
252259
System.out.println("Training...");
253-
Iteration iteration = trainer.trainProject(project.id());
260+
Iteration iteration = trainer.trainProject(project.id(), new TrainProjectOptionalParameter());
261+
254262
while (iteration.status().equals("Training"))
255263
{
256264
System.out.println("Training Status: "+ iteration.status());
257265
Thread.sleep(5000);
258266
iteration = trainer.getIteration(project.id(), iteration.id());
259267
}
260268
System.out.println("Training Status: "+ iteration.status());
261-
trainer.updateIteration(project.id(), iteration.id(), iteration.withIsDefault(true));
269+
270+
// The iteration is now trained. Publish it to the prediction endpoint.
271+
String publishedModelName = "myModel";
272+
String predictionResourceId = System.getenv("AZURE_CUSTOMVISION_PREDICTION_ID");
273+
trainer.publishIteration(project.id(), iteration.id(), publishedModelName, predictionResourceId);
262274
// </snippet_train_od>
263275

264276
// use below for url
265277
// String url = "some url";
266-
// ImagePrediction results = predictor.predictions().predictImage()
278+
// ImagePrediction results = predictor.predictions().detectImageUrl()
267279
// .withProjectId(project.id())
280+
// .withPublishedName(publishedModelName)
268281
// .withUrl(url)
269282
// .execute();
270283

@@ -273,8 +286,9 @@ public static void ObjectDetection_Sample(TrainingApi trainClient, PredictionEnd
273286
byte[] testImage = GetImage("/ObjectTest", "test_image.jpg");
274287

275288
// predict
276-
ImagePrediction results = predictor.predictions().predictImage()
289+
ImagePrediction results = predictor.predictions().detectImage()
277290
.withProjectId(project.id())
291+
.withPublishedName(publishedModelName)
278292
.withImageData(testImage)
279293
.execute();
280294

@@ -347,11 +361,13 @@ public static void main(String[] args) {
347361
//=============================================================
348362
// Authenticate
349363

350-
final String trainingApiKey = System.getenv("AZURE_CUSTOMVISION_TRAINING_API_KEY");;
351-
final String predictionApiKey = System.getenv("AZURE_CUSTOMVISION_PREDICTION_API_KEY");;
364+
final String CustomVisionTrainingClientKey = System.getenv("AZURE_CUSTOMVISION_TRAINING_API_KEY");
365+
final String predictionApiKey = System.getenv("AZURE_CUSTOMVISION_PREDICTION_API_KEY");
366+
367+
final String Endpoint = System.getenv("AZURE_CUSTOMVISION_ENDPOINT");
352368

353-
TrainingApi trainClient = CustomVisionTrainingManager.authenticate(trainingApiKey);
354-
PredictionEndpoint predictClient = CustomVisionPredictionManager.authenticate(predictionApiKey);
369+
CustomVisionTrainingClient trainClient = CustomVisionTrainingManager.authenticate("https://{Endpoint}/customvision/v3.0/training/", CustomVisionTrainingClientKey).withEndpoint(Endpoint);
370+
CustomVisionPredictionClient predictClient = CustomVisionPredictionManager.authenticate("https://{Endpoint}/customvision/v3.0/prediction/", predictionApiKey).withEndpoint(Endpoint);
355371

356372
runSample(trainClient, predictClient);
357373
} catch (Exception e) {

0 commit comments

Comments
 (0)