Skip to content

Commit d2aa20f

Browse files
areddishlmazuel
authored andcommitted
Add a sample for Multiclass training (#12)
* Add a sample for Multiclass training * Update call to assign arg
1 parent f3a71b8 commit d2aa20f

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import os
2+
import time
3+
4+
from azure.cognitiveservices.vision.customvision.training import training_api
5+
from azure.cognitiveservices.vision.customvision.training.models import Classifier
6+
7+
SUBSCRIPTION_KEY_ENV_NAME = "CUSTOMVISION_TRAINING_KEY"
8+
SAMPLE_PROJECT_NAME = "Python SDK Sample"
9+
10+
IMAGES_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), "images")
11+
12+
def train_project(subscription_key):
13+
14+
trainer = training_api.TrainingApi(subscription_key)
15+
16+
# Create a new project
17+
print ("Creating project...")
18+
project = trainer.create_project(SAMPLE_PROJECT_NAME, classification_type=Classifier.multiclass)
19+
20+
# Make two tags in the new project
21+
hemlock_tag = trainer.create_tag(project.id, "Hemlock")
22+
cherry_tag = trainer.create_tag(project.id, "Japanese Cherry")
23+
pine_needle_tag = trainer.create_tag(project.id, "Pine Needle Leaves")
24+
flat_leaf_tag = trainer.create_tag(project.id, "Flat Leaves")
25+
26+
print ("Adding images...")
27+
hemlock_dir = os.path.join(IMAGES_FOLDER, "Hemlock")
28+
for image in os.listdir(hemlock_dir):
29+
with open(os.path.join(hemlock_dir, image), mode="rb") as img_data:
30+
trainer.create_images_from_data(project.id, img_data.read(), [ hemlock_tag.id, pine_needle_tag.id ])
31+
32+
cherry_dir = os.path.join(IMAGES_FOLDER, "Japanese Cherry")
33+
for image in os.listdir(cherry_dir):
34+
with open(os.path.join(cherry_dir, image), mode="rb") as img_data:
35+
trainer.create_images_from_data(project.id, img_data.read(), [ cherry_tag.id, flat_leaf_tag.id ])
36+
37+
print ("Training...")
38+
iteration = trainer.train_project(project.id)
39+
while (iteration.status == "Training"):
40+
iteration = trainer.get_iteration(project.id, iteration.id)
41+
print ("Training status: " + iteration.status)
42+
time.sleep(1)
43+
44+
# The iteration is now trained. Make it the default project endpoint
45+
trainer.update_iteration(project.id, iteration.id, is_default=True)
46+
print ("Done!")
47+
return project
48+
49+
if __name__ == "__main__":
50+
import sys, os.path
51+
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
52+
from tools import execute_samples
53+
execute_samples(globals(), SUBSCRIPTION_KEY_ENV_NAME)

0 commit comments

Comments
 (0)