Skip to content

Commit 5b39d84

Browse files
authored
update model management APIs (#564)
1 parent 081a0f4 commit 5b39d84

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

Samples/WindowsML/Resources/SqueezeNetModelCatalog.json

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
"base": "https://learn.microsoft.com/azure/ai-foundry/foundry-local/reference/reference-catalog-api",
33
"models": [
44
{
5-
"alias": "squeezenet",
5+
"id": "squeezenet",
66
"name": "squeezenet",
77
"version": "1",
8-
"modelType": "ONNX",
98
"publisher": "microsoft",
10-
"executionProvider": "CPUExecutionProvider",
11-
"description": "",
9+
"executionProviders": [
10+
{
11+
"name": "CPUExecutionProvider"
12+
}
13+
],
1214
"modelSizeBytes": 1250000,
1315
"license": "BSD",
1416
"licenseUri": "https://github.com/microsoft/WindowsAppSDK-Samples/raw/refs/heads/release/experimental/Samples/WindowsML/cpp/CppConsoleDesktop/CppConsoleDesktop/SqueezeNet.LICENSE.txt",
@@ -33,13 +35,15 @@
3335
]
3436
},
3537
{
36-
"alias": "squeezenet-fp32",
38+
"id": "squeezenet-fp32",
3739
"name": "squeezenet-fp32",
3840
"version": "1.1-7",
39-
"modelType": "ONNX",
4041
"publisher": "microsoft",
41-
"executionProvider": "NvTensorRTRTXExecutionProvider",
42-
"description": "",
42+
"executionProviders": [
43+
{
44+
"name": "NvTensorRTRTXExecutionProvider"
45+
}
46+
],
4347
"modelSizeBytes": 4730000,
4448
"license": "BSD",
4549
"licenseUri": "https://github.com/microsoft/WindowsAppSDK-Samples/blob/main/Samples/WindowsML/Resources/SqueezeNet.LICENSE.txt",

Samples/WindowsML/Shared/cs/ModelManager.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,33 +126,33 @@ public static string GetModelVariantPath(string executableFolder, ModelVariant v
126126
// Build source
127127
string sampleCatalogJsonPath = Path.Combine(executableFolder, "SqueezeNetModelCatalog.json");
128128
var uri = new System.Uri(sampleCatalogJsonPath);
129-
var sampleCatalogSource = await CatalogModelSource.CreateFromUri(uri);
130-
131-
WinMLModelCatalog modelCatalog = new WinMLModelCatalog(new[] { sampleCatalogSource });
132-
129+
var sampleCatalogSource = await ModelCatalogSource.CreateFromUriAsync(uri);
130+
131+
ModelCatalog modelCatalog = new ModelCatalog(new[] { sampleCatalogSource });
132+
133133
// Use intelligent model variant selection based on execution provider and device capabilities
134134
ModelVariant actualVariant = DetermineModelVariant(options, ortEnv);
135135

136136
CatalogModelInfo modelFromCatalog;
137137

138138
string modelVariantName = (actualVariant == ModelVariant.FP32) ? "squeezenet-fp32" : "squeezenet";
139139

140-
modelFromCatalog = await modelCatalog.FindModel(modelVariantName);
140+
modelFromCatalog = await modelCatalog.FindModelAsync(modelVariantName);
141141

142142
if (modelFromCatalog != null)
143143
{
144144
var additionalHeaders = new Dictionary<string, string>();
145-
var catalogModelInstanceOp = modelFromCatalog.GetInstance(additionalHeaders);
145+
var catalogModelInstanceOp = modelFromCatalog.GetInstanceAsync(additionalHeaders);
146146

147147
catalogModelInstanceOp.Progress += (operation, progress) => {
148148
Console.Write($"Model download progress: {progress}%\r");
149149
};
150150

151151
var catalogModelInstanceResult = await catalogModelInstanceOp;
152152

153-
if (catalogModelInstanceResult.Status == CatalogModelStatus.Available)
153+
if (catalogModelInstanceResult.Status == CatalogModelInstanceStatus.Available)
154154
{
155-
var catalogModelInstance = catalogModelInstanceResult.Instance;
155+
using var catalogModelInstance = catalogModelInstanceResult.GetInstance();
156156
var modelPaths = catalogModelInstance.ModelPaths;
157157

158158
string modelFolderPath = modelPaths[0];

Samples/WindowsML/cpp/CppConsoleDesktop/CppConsoleDesktop.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ IAsyncAction RunInferenceAsync(const CommandLineOptions& options)
4141

4242
auto sampleCatalogJsonPath = executableFolder / L"SqueezeNetModelCatalog.json";
4343
auto uri = winrt::Windows::Foundation::Uri(sampleCatalogJsonPath.c_str());
44-
auto sampleCatalogSource = winrt::Microsoft::Windows::AI::MachineLearning::CatalogModelSource::CreateFromUri(uri).get();
44+
auto sampleCatalogSource = winrt::Microsoft::Windows::AI::MachineLearning::ModelCatalogSource::CreateFromUriAsync(uri).get();
4545

46-
winrt::Microsoft::Windows::AI::MachineLearning::WinMLModelCatalog modelCatalog({sampleCatalogSource});
46+
winrt::Microsoft::Windows::AI::MachineLearning::ModelCatalog modelCatalog({sampleCatalogSource});
4747

4848
// Use intelligent model variant selection based on execution provider and device capabilities
4949
ModelVariant actualVariant = ModelManager::DetermineModelVariant(options, env);
@@ -52,21 +52,21 @@ IAsyncAction RunInferenceAsync(const CommandLineOptions& options)
5252

5353
std::wstring modelVariantName = (actualVariant == ModelVariant::FP32) ? L"squeezenet-fp32" : L"squeezenet";
5454

55-
modelFromCatalog = modelCatalog.FindModel(modelVariantName.c_str()).get();
55+
modelFromCatalog = modelCatalog.FindModelAsync(modelVariantName.c_str()).get();
5656

5757
if (modelFromCatalog != nullptr)
5858
{
59-
auto catalogModelInstanceOp = modelFromCatalog.GetInstance({});
59+
auto catalogModelInstanceOp = modelFromCatalog.GetInstanceAsync({});
6060

6161
catalogModelInstanceOp.Progress([](auto const& /*operation*/, double progress) {
6262
std::wcout << L"Model download progress: " << progress << L"%\r";
6363
});
6464

6565
auto catalogModelInstanceResult = co_await catalogModelInstanceOp;
6666

67-
if (catalogModelInstanceResult.Status() == winrt::Microsoft::Windows::AI::MachineLearning::CatalogModelStatus::Available)
67+
if (catalogModelInstanceResult.Status() == winrt::Microsoft::Windows::AI::MachineLearning::CatalogModelInstanceStatus::Available)
6868
{
69-
auto catalogModelInstance = catalogModelInstanceResult.Instance();
69+
auto catalogModelInstance = catalogModelInstanceResult.GetInstance();
7070
auto modelPaths = catalogModelInstance.ModelPaths();
7171

7272
auto modelFolderPath = std::filesystem::path(modelPaths.GetAt(0).c_str());

0 commit comments

Comments
 (0)