Skip to content

Commit 382a2b5

Browse files
areddishlmazuel
authored andcommitted
Add Custom Vision Training and Prediction Samples (#8)
* Add Custom Vision Training and Prediction Samples * Update readme, requirements to install and keys. * Add the samples dir to import path so imports can be found. * Add missing import
1 parent 5033504 commit 382a2b5

26 files changed

+108
-1
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ This project framework provides examples for the following services:
2424

2525
* Using the **Computer Vision SDK** [azure-cognitiveservices-vision-computervision](http://pypi.python.org/pypi/azure-cognitiveservices-vision-computervision) for the [Computer Vision API](https://azure.microsoft.com/services/cognitive-services/computer-vision/)
2626
* Using the **Content Moderator SDK** [azure-cognitiveservices-vision-contentmoderator](http://pypi.python.org/pypi/azure-cognitiveservices-vision-contentmoderator) for the [Content Moderator API](https://azure.microsoft.com/services/cognitive-services/content-moderator/)
27+
* Using the **Custom Vision SDK** [azure-cognitiveservices-vision-customvision](http://pypi.python.org/pypi/azure-cognitiveservices-vision-customvision) for the [Custom Vision API](https://azure.microsoft.com/services/cognitive-services/custom-vision-service/)
2728

2829
We provide several meta-packages to help you install several packages at a time. Please note that meta-packages are only recommended for development purpose. It's recommended in production to always pin specific version of individual packages.
2930

@@ -79,6 +80,8 @@ We provide several meta-packages to help you install several packages at a time.
7980
4. Set up the environment variable `WEBSEARCH_SUBSCRIPTION_KEY` with your key if you want to execute WebSearch tests.
8081
4. Set up the environment variable `COMPUTERVISION_SUBSCRIPTION_KEY` with your key if you want to execute Computer Vision tests. You might override too `COMPUTERVISION_LOCATION` (westcentralus by default).
8182
4. Set up the environment variable `CONTENTMODERATOR_SUBSCRIPTION_KEY` with your key if you want to execute Content Moderator tests. You might override too `CONTENTMODERATOR_LOCATION` (westcentralus by default).
83+
4. Set up the environment variable `CUSTOMVISION_TRAINING_KEY` with your key if you want to execute CustomVision Training tests.
84+
4. Set up the environment variable `CUSTOMVISION_PREDICTION_KEY` with your key if you want to execute CustomVision Prediction tests.
8285
8386
## Demo
8487

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ azure-cognitiveservices-search-newssearch
77
azure-cognitiveservices-search-videosearch
88
azure-cognitiveservices-search-websearch
99
azure-cognitiveservices-vision-computervision
10-
azure-cognitiveservices-vision-contentmoderator
10+
azure-cognitiveservices-vision-contentmoderator
11+
azure-cognitiveservices-vision-customvision
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import sys
3+
4+
from azure.cognitiveservices.vision.customvision.training import training_api
5+
from azure.cognitiveservices.vision.customvision.prediction import prediction_endpoint
6+
from azure.cognitiveservices.vision.customvision.prediction.prediction_endpoint import models
7+
8+
TRAINING_KEY_ENV_NAME = "CUSTOMVISION_TRAINING_KEY"
9+
SUBSCRIPTION_KEY_ENV_NAME = "CUSTOMVISION_PREDICTION_KEY"
10+
11+
# Add this directory to the path so that custom_vision_training_samples can be found
12+
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "."))
13+
14+
IMAGES_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), "images")
15+
16+
def find_or_train_project():
17+
try:
18+
training_key = os.environ[TRAINING_KEY_ENV_NAME]
19+
except KeyError:
20+
raise SubscriptionKeyError("You need to set the {} env variable.".format(TRAINING_KEY_ENV_NAME))
21+
22+
# Use the training API to find the SDK sample project created from the training example.
23+
from custom_vision_training_samples import train_project, SAMPLE_PROJECT_NAME
24+
trainer = training_api.TrainingApi(training_key)
25+
26+
for proj in trainer.get_projects():
27+
if (proj.name == SAMPLE_PROJECT_NAME):
28+
return proj
29+
30+
# Or, if not found, we will run the training example to create it.
31+
return train_project(training_key)
32+
33+
def predict_project(subscription_key):
34+
predictor = prediction_endpoint.PredictionEndpoint(subscription_key)
35+
36+
# Find or train a new project to use for prediction.
37+
project = find_or_train_project()
38+
39+
with open(os.path.join(IMAGES_FOLDER, "Test", "test_image.jpg"), mode="rb") as test_data:
40+
results = predictor.predict_image(project.id, test_data.read())
41+
42+
# Display the results.
43+
for prediction in results.predictions:
44+
print ("\t" + prediction.tag + ": {0:.2f}%".format(prediction.probability * 100))
45+
46+
47+
if __name__ == "__main__":
48+
import sys, os.path
49+
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
50+
from tools import execute_samples, SubscriptionKeyError
51+
execute_samples(globals(), SUBSCRIPTION_KEY_ENV_NAME)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import time
3+
4+
from azure.cognitiveservices.vision.customvision.training import training_api
5+
6+
SUBSCRIPTION_KEY_ENV_NAME = "CUSTOMVISION_TRAINING_KEY"
7+
SAMPLE_PROJECT_NAME = "Python SDK Sample"
8+
9+
IMAGES_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), "images")
10+
11+
def train_project(subscription_key):
12+
13+
trainer = training_api.TrainingApi(subscription_key)
14+
15+
# Create a new project
16+
print ("Creating project...")
17+
project = trainer.create_project(SAMPLE_PROJECT_NAME)
18+
19+
# Make two tags in the new project
20+
hemlock_tag = trainer.create_tag(project.id, "Hemlock")
21+
cherry_tag = trainer.create_tag(project.id, "Japanese Cherry")
22+
23+
print ("Adding images...")
24+
hemlock_dir = os.path.join(IMAGES_FOLDER, "Hemlock")
25+
for image in os.listdir(hemlock_dir):
26+
with open(os.path.join(hemlock_dir, image), mode="rb") as img_data:
27+
trainer.create_images_from_data(project.id, img_data.read(), [ hemlock_tag.id ])
28+
29+
cherry_dir = os.path.join(IMAGES_FOLDER, "Japanese Cherry")
30+
for image in os.listdir(cherry_dir):
31+
with open(os.path.join(cherry_dir, image), mode="rb") as img_data:
32+
trainer.create_images_from_data(project.id, img_data.read(), [ cherry_tag.id ])
33+
34+
print ("Training...")
35+
iteration = trainer.train_project(project.id)
36+
while (iteration.status == "Training"):
37+
iteration = trainer.get_iteration(project.id, iteration.id)
38+
print ("Training status: " + iteration.status)
39+
time.sleep(1)
40+
41+
# The iteration is now trained. Make it the default project endpoint
42+
trainer.update_iteration(project.id, iteration.id, is_default=True)
43+
print ("Done!")
44+
return project
45+
46+
if __name__ == "__main__":
47+
import sys, os.path
48+
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
49+
from tools import execute_samples
50+
execute_samples(globals(), SUBSCRIPTION_KEY_ENV_NAME)
263 KB
Loading
74 KB
Loading
244 KB
Loading
286 KB
Loading
166 KB
Loading
83 KB
Loading

0 commit comments

Comments
 (0)