1+ import os
2+ import time
3+
4+ from azure .cognitiveservices .vision .customvision .training import training_api
5+ from azure .cognitiveservices .vision .customvision .training .models import Classifier
6+
7+ SUBSCRIPTION_KEY_ENV_NAME = "CUSTOMVISION_TRAINING_KEY"
8+ SAMPLE_PROJECT_NAME = "Python SDK Sample"
9+
10+ IMAGES_FOLDER = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "images" )
11+
12+ def train_project (subscription_key ):
13+
14+ trainer = training_api .TrainingApi (subscription_key )
15+
16+ # Create a new project
17+ print ("Creating project..." )
18+ project = trainer .create_project (SAMPLE_PROJECT_NAME , classification_type = Classifier .multiclass )
19+
20+ # Make two tags in the new project
21+ hemlock_tag = trainer .create_tag (project .id , "Hemlock" )
22+ cherry_tag = trainer .create_tag (project .id , "Japanese Cherry" )
23+ pine_needle_tag = trainer .create_tag (project .id , "Pine Needle Leaves" )
24+ flat_leaf_tag = trainer .create_tag (project .id , "Flat Leaves" )
25+
26+ print ("Adding images..." )
27+ hemlock_dir = os .path .join (IMAGES_FOLDER , "Hemlock" )
28+ for image in os .listdir (hemlock_dir ):
29+ with open (os .path .join (hemlock_dir , image ), mode = "rb" ) as img_data :
30+ trainer .create_images_from_data (project .id , img_data .read (), [ hemlock_tag .id , pine_needle_tag .id ])
31+
32+ cherry_dir = os .path .join (IMAGES_FOLDER , "Japanese Cherry" )
33+ for image in os .listdir (cherry_dir ):
34+ with open (os .path .join (cherry_dir , image ), mode = "rb" ) as img_data :
35+ trainer .create_images_from_data (project .id , img_data .read (), [ cherry_tag .id , flat_leaf_tag .id ])
36+
37+ print ("Training..." )
38+ iteration = trainer .train_project (project .id )
39+ while (iteration .status == "Training" ):
40+ iteration = trainer .get_iteration (project .id , iteration .id )
41+ print ("Training status: " + iteration .status )
42+ time .sleep (1 )
43+
44+ # The iteration is now trained. Make it the default project endpoint
45+ trainer .update_iteration (project .id , iteration .id , is_default = True )
46+ print ("Done!" )
47+ return project
48+
49+ if __name__ == "__main__" :
50+ import sys , os .path
51+ sys .path .append (os .path .abspath (os .path .join (__file__ , ".." , ".." )))
52+ from tools import execute_samples
53+ execute_samples (globals (), SUBSCRIPTION_KEY_ENV_NAME )
0 commit comments