Commit 951d588

Merge pull request #111816 from bhimar/patch-2
Update how-to-auto-train-image-models.md
2 parents: 2d6b615 + cc09f8b

File tree

1 file changed: +37 −2 lines changed

articles/machine-learning/how-to-auto-train-image-models.md

Lines changed: 37 additions & 2 deletions
@@ -222,7 +222,7 @@ validation_data:

# [Python SDK](#tab/python)

-[!INCLUDE [sdk v2](includes/machine-learning-sdk-v2.md)]
+[!INCLUDE [sdk v2](includes/machine-learning-sdk-v2.md)]

You can create data inputs from training and validation MLTable from your local directory or cloud storage with the following code:
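
For reference, such inputs are typically built with the SDK v2 `Input` type. A minimal sketch, assuming MLTable folders at the placeholder paths shown:

```python
from azure.ai.ml import Input
from azure.ai.ml.constants import AssetTypes

# Placeholder paths; point these at your own MLTable folders (local or cloud URI).
my_training_data_input = Input(type=AssetTypes.MLTABLE, path="./data/training-mltable-folder")
my_validation_data_input = Input(type=AssetTypes.MLTABLE, path="./data/validation-mltable-folder")
```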

@@ -359,14 +359,49 @@ In individual trials, you directly control the model architecture and hyperparam

#### Supported model architectures

-The following table summarizes the supported models for each computer vision task.
+The following table summarizes the supported legacy models for each computer vision task. Using only these legacy models triggers runs on the legacy runtime, where each individual run or trial is submitted as a command job. See below for HuggingFace and MMDetection support.

Task | model architectures | String literal syntax<br> ***`default_model`\**** denoted with \*
---|----------|----------
Image classification<br> (multi-class and multi-label)| **MobileNet**: Lightweight models for mobile applications <br> **ResNet**: Residual networks<br> **ResNeSt**: Split attention networks<br> **SE-ResNeXt50**: Squeeze-and-Excitation networks<br> **ViT**: Vision transformer networks| `mobilenetv2` <br>`resnet18` <br>`resnet34` <br> `resnet50` <br> `resnet101` <br> `resnet152` <br> `resnest50` <br> `resnest101` <br> `seresnext` <br> `vits16r224` (small) <br> ***`vitb16r224`\**** (base) <br>`vitl16r224` (large)
Object detection | **YOLOv5**: One-stage object detection model <br> **Faster RCNN ResNet FPN**: Two-stage object detection models <br> **RetinaNet ResNet FPN**: Addresses class imbalance with Focal Loss <br> <br>*Note: Refer to the [`model_size` hyperparameter](reference-automl-images-hyperparameters.md#model-specific-hyperparameters) for YOLOv5 model sizes.*| ***`yolov5`\**** <br> `fasterrcnn_resnet18_fpn` <br> `fasterrcnn_resnet34_fpn` <br> `fasterrcnn_resnet50_fpn` <br> `fasterrcnn_resnet101_fpn` <br> `fasterrcnn_resnet152_fpn` <br> `retinanet_resnet50_fpn`
Instance segmentation | **MaskRCNN ResNet FPN**| `maskrcnn_resnet18_fpn` <br> `maskrcnn_resnet34_fpn` <br> ***`maskrcnn_resnet50_fpn`\**** <br> `maskrcnn_resnet101_fpn` <br> `maskrcnn_resnet152_fpn`
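
To pick one of these architectures in an individual trial, you pass its string literal as the model name in the job's training parameters. A minimal sketch, assuming an object detection job built with the SDK v2 AutoML factory functions (the compute, experiment name, and data inputs are placeholders):

```python
from azure.ai.ml import automl

# Placeholder job; reuses the MLTable inputs sketched earlier.
image_job = automl.image_object_detection(
    compute="gpu-cluster",                 # assumed compute target name
    experiment_name="automl-images-demo",  # assumed experiment name
    training_data=my_training_data_input,
    validation_data=my_validation_data_input,
    target_column_name="label",
)

# Select a supported architecture by its string literal; yolov5 is the default.
image_job.set_training_parameters(model_name="yolov5")
```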

+#### Supported model architectures - HuggingFace and MMDetection (preview)
+
+With the new backend that runs on [Azure Machine Learning pipelines](concept-ml-pipelines.md), you can additionally use any image classification model from the [HuggingFace Hub](https://huggingface.co/models?pipeline_tag=image-classification&library=transformers) that is part of the transformers library (such as microsoft/beit-base-patch16-224), as well as any object detection or instance segmentation model from the [MMDetection Version 2.28.2 Model Zoo](https://mmdetection.readthedocs.io/en/v2.28.2/model_zoo.html) (such as atss_r50_fpn_1x_coco).
+
+In addition to supporting any model from HuggingFace Transformers and MMDetection 2.28.2, we also offer a list of curated models from these libraries in the azureml-staging registry. These curated models have been tested thoroughly and use default hyperparameters selected from extensive benchmarking to ensure effective training. The following table summarizes these curated models.
+
+Task | model architectures | String literal syntax
+---|----------|----------
+Image classification<br> (multi-class and multi-label)| **BEiT** <br> **ViT** <br> **DeiT** <br> **SwinV2** | [`microsoft/beit-base-patch16-224-pt22k-ft22k`](https://ml.azure.com/registries/azureml/models/microsoft-beit-base-patch16-224-pt22k-ft22k/version/5)<br> [`google/vit-base-patch16-224`](https://ml.azure.com/registries/azureml/models/google-vit-base-patch16-224/version/5)<br> [`facebook/deit-base-patch16-224`](https://ml.azure.com/registries/azureml/models/facebook-deit-base-patch16-224/version/5)<br> [`microsoft/swinv2-base-patch4-window12-192-22k`](https://ml.azure.com/registries/azureml/models/microsoft-swinv2-base-patch4-window12-192-22k/version/5)
+Object detection | **Sparse R-CNN** <br> **Deformable DETR** <br> **VFNet** <br> **YOLOF** <br> **Swin** | [`sparse_rcnn_r50_fpn_300_proposals_crop_mstrain_480-800_3x_coco`](https://ml.azure.com/registries/azureml/models/sparse_rcnn_r50_fpn_300_proposals_crop_mstrain_480-800_3x_coco/version/3)<br> [`sparse_rcnn_r101_fpn_300_proposals_crop_mstrain_480-800_3x_coco`](https://ml.azure.com/registries/azureml/models/sparse_rcnn_r101_fpn_300_proposals_crop_mstrain_480-800_3x_coco/version/3) <br> [`deformable_detr_twostage_refine_r50_16x2_50e_coco`](https://ml.azure.com/registries/azureml/models/deformable_detr_twostage_refine_r50_16x2_50e_coco/version/3) <br> [`vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco`](https://ml.azure.com/registries/azureml/models/vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco/version/3) <br> [`vfnet_x101_64x4d_fpn_mdconv_c3-c5_mstrain_2x_coco`](https://ml.azure.com/registries/azureml/models/vfnet_x101_64x4d_fpn_mdconv_c3-c5_mstrain_2x_coco/version/3) <br> [`yolof_r50_c5_8x8_1x_coco`](https://ml.azure.com/registries/azureml/models/yolof_r50_c5_8x8_1x_coco/version/3)
+Instance segmentation | **Swin** | [`mask_rcnn_swin-t-p4-w7_fpn_1x_coco`](https://ml.azure.com/registries/azureml/models/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/version/3)
+
+We constantly update the list of curated models. You can get the most up-to-date list of curated models for a given task by using the Python SDK:
+```python
+from azure.identity import DefaultAzureCredential
+from azure.ai.ml import MLClient
+
+# Connect to the registry that hosts the curated models.
+credential = DefaultAzureCredential()
+ml_client = MLClient(credential, registry_name="azureml-staging")
+
+# Keep only the models tagged with the image task of interest.
+models = ml_client.models.list()
+classification_models = []
+for model in models:
+    model = ml_client.models.get(model.name, label="latest")
+    if model.tags['task'] == 'image-classification':  # choose an image task
+        classification_models.append(model.name)
+
+classification_models
+```
+Output:
+```
+['google-vit-base-patch16-224',
+ 'microsoft-swinv2-base-patch4-window12-192-22k',
+ 'facebook-deit-base-patch16-224',
+ 'microsoft-beit-base-patch16-224-pt22k-ft22k']
+```
+Using any HuggingFace or MMDetection model triggers runs that use pipeline components. If both legacy and HuggingFace/MMDetection models are used, all runs/trials are triggered through components.
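
Assuming the same job object as in the earlier sketch, switching to a HuggingFace or MMDetection model is just a different `model_name` string, for example the MMDetection model mentioned above:

```python
# Any supported MMDetection model name (or, for classification jobs,
# a HuggingFace image classification model id) can be passed here.
image_job.set_training_parameters(model_name="atss_r50_fpn_1x_coco")
```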
In addition to controlling the model architecture, you can also tune hyperparameters used for model training. While many of the hyperparameters exposed are model-agnostic, there are instances where hyperparameters are task-specific or model-specific. [Learn more about the available hyperparameters for these instances](reference-automl-images-hyperparameters.md).
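
As a rough illustration, tuning a couple of the model-agnostic hyperparameters on the job sketched earlier might look like the following; the values are placeholders, and the parameter names are taken from the hyperparameter reference linked above:

```python
# Placeholder values; see the hyperparameter reference for ranges and defaults.
image_job.set_training_parameters(
    model_name="yolov5",
    learning_rate=0.005,
    number_of_epochs=30,
)
```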
