Skip to content

Commit e6dabc5

Browse files
sagarsumant and Sagar Sumant authored
Add MaaP Finetuning Job implementation. (Azure#38650)
* WIP
* Fix circular import
* Add queue_settings and resources fields.
* Work in progress.
* Test run successful
* All 3 flavors working in integration tests.
* add unit tests for yaml.
* Add tests for yaml job creation.
* Fix tests.
* Fix test.
* Refactor and update tests.
* Fix things.
* Fix tests.
* Fix failing test.
* Fix tests.
* add @pytest.mark.e2etest
* try recorded test.
* comment status for non live testing.
* Fix things.
* Rebase and fix merge issues.
* format using black.
* format using black.
* Disable recorded test execution
* Remove assets.json changes.
* Fix linting errors.
* Fix imports
* Fix comments.
* Remove job.py changes.
* Fix tests.

---------

Co-authored-by: Sagar Sumant <[email protected]>
1 parent 262b16a commit e6dabc5

28 files changed, +1330 -285 lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -709,6 +709,7 @@ def __init__(
709709
_service_client_kwargs=kwargs,
710710
requests_pipeline=self._requests_pipeline,
711711
service_client_01_2024_preview=self._service_client_01_2024_preview,
712+
service_client_10_2024_preview=self._service_client_10_2024_preview,
712713
**ops_kwargs,
713714
)
714715
self._operation_container.add(AzureMLResourceType.JOB, self._jobs)
@@ -746,7 +747,8 @@ def __init__(
746747
**ops_kwargs, # type: ignore[arg-type]
747748
)
748749
self._operation_container.add(
749-
AzureMLResourceType.VIRTUALCLUSTER, self._virtual_clusters # type: ignore[arg-type]
750+
AzureMLResourceType.VIRTUALCLUSTER,
751+
self._virtual_clusters, # type: ignore[arg-type]
750752
)
751753
except Exception as ex: # pylint: disable=broad-except
752754
module_logger.debug("Virtual Cluster operations could not be initialized due to %s ", ex)

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_finetuning/finetuning_job.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,17 @@
77
from azure.ai.ml._schema.job import BaseJobSchema
88
from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
99
from azure.ai.ml._utils._experimental import experimental
10+
from azure.ai.ml._schema.core.fields import (
11+
NestedField,
12+
)
13+
from ..queue_settings import QueueSettingsSchema
14+
from ..job_resources import JobResourcesSchema
1015

1116
# This is meant to match the yaml definition NOT the models defined in _restclient
1217

1318

1419
@experimental
1520
class FineTuningJobSchema(BaseJobSchema):
1621
outputs = OutputsField()
22+
queue_settings = NestedField(QueueSettingsSchema)
23+
resources = NestedField(JobResourcesSchema)
Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,21 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
# pylint: disable=unused-argument
6+
7+
from marshmallow import fields, post_load
8+
9+
from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
10+
11+
12+
class JobResourcesSchema(metaclass=PatchedSchemaMeta):
13+
instance_types = fields.List(
14+
fields.Str(), metadata={"description": "The instance type to make available to this job."}
15+
)
16+
17+
@post_load
18+
def make(self, data, **kwargs):
19+
from azure.ai.ml.entities import JobResources
20+
21+
return JobResources(**data)
Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,17 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
6+
class FineTuningTaskType:
7+
CHAT_COMPLETION = "ChatCompletion"
8+
TEXT_COMPLETION = "TextCompletion"
9+
TEXT_CLASSIFICATION = "TextClassification"
10+
QUESTION_ANSWERING = "QuestionAnswering"
11+
TEXT_SUMMARIZATION = "TextSummarization"
12+
TOKEN_CLASSIFICATION = "TokenClassification"
13+
TEXT_TRANSLATION = "TextTranslation"
14+
IMAGE_CLASSIFICATION = "ImageClassification"
15+
IMAGE_INSTANCE_SEGMENTATION = "ImageInstanceSegmentation"
16+
IMAGE_OBJECT_DETECTION = "ImageObjectDetection"
17+
VIDEO_MULTI_OBJECT_TRACKING = "VideoMultiObjectTracking"

sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py

Lines changed: 54 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,15 +52,29 @@
5252
from ._component.pipeline_component import PipelineComponent
5353
from ._component.spark_component import SparkComponent
5454
from ._compute._aml_compute_node_info import AmlComputeNodeInfo
55-
from ._compute._custom_applications import CustomApplications, EndpointsSettings, ImageSettings, VolumeSettings
55+
from ._compute._custom_applications import (
56+
CustomApplications,
57+
EndpointsSettings,
58+
ImageSettings,
59+
VolumeSettings,
60+
)
5661
from ._compute._image_metadata import ImageMetadata
57-
from ._compute._schedule import ComputePowerAction, ComputeSchedules, ComputeStartStopSchedule, ScheduleState
62+
from ._compute._schedule import (
63+
ComputePowerAction,
64+
ComputeSchedules,
65+
ComputeStartStopSchedule,
66+
ScheduleState,
67+
)
5868
from ._compute._setup_scripts import ScriptReference, SetupScripts
5969
from ._compute._usage import Usage, UsageName
6070
from ._compute._vm_size import VmSize
6171
from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings
6272
from ._compute.compute import Compute, NetworkSettings
63-
from ._compute.compute_instance import AssignedUserConfiguration, ComputeInstance, ComputeInstanceSshSettings
73+
from ._compute.compute_instance import (
74+
AssignedUserConfiguration,
75+
ComputeInstance,
76+
ComputeInstanceSshSettings,
77+
)
6478
from ._compute.kubernetes_compute import KubernetesCompute
6579
from ._compute.synapsespark_compute import AutoPauseSettings, AutoScaleSettings, SynapseSparkCompute
6680
from ._compute.unsupported_compute import UnsupportedCompute
@@ -84,7 +98,11 @@
8498
from ._data_import.data_import import DataImport
8599
from ._data_import.schedule import ImportDataSchedule
86100
from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore
87-
from ._datastore.azure_storage import AzureBlobDatastore, AzureDataLakeGen2Datastore, AzureFileDatastore
101+
from ._datastore.azure_storage import (
102+
AzureBlobDatastore,
103+
AzureDataLakeGen2Datastore,
104+
AzureFileDatastore,
105+
)
88106
from ._datastore.datastore import Datastore
89107
from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore
90108
from ._deployment.batch_deployment import BatchDeployment
@@ -94,7 +112,11 @@
94112
from ._deployment.data_asset import DataAsset
95113
from ._deployment.data_collector import DataCollector
96114
from ._deployment.deployment_collection import DeploymentCollection
97-
from ._deployment.deployment_settings import BatchRetrySettings, OnlineRequestSettings, ProbeSettings
115+
from ._deployment.deployment_settings import (
116+
BatchRetrySettings,
117+
OnlineRequestSettings,
118+
ProbeSettings,
119+
)
98120
from ._deployment.model_batch_deployment import ModelBatchDeployment
99121
from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings
100122
from ._deployment.online_deployment import (
@@ -106,7 +128,11 @@
106128
from ._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
107129
from ._deployment.request_logging import RequestLogging
108130
from ._deployment.resource_requirements_settings import ResourceRequirementsSettings
109-
from ._deployment.scale_settings import DefaultScaleSettings, OnlineScaleSettings, TargetUtilizationScaleSettings
131+
from ._deployment.scale_settings import (
132+
DefaultScaleSettings,
133+
OnlineScaleSettings,
134+
TargetUtilizationScaleSettings,
135+
)
110136
from ._endpoint.batch_endpoint import BatchEndpoint
111137
from ._endpoint.endpoint import Endpoint
112138
from ._endpoint.online_endpoint import (
@@ -136,11 +162,19 @@
136162
from ._indexes import ModelConfiguration as IndexModelConfiguration
137163
from ._job.command_job import CommandJob
138164
from ._job.compute_configuration import ComputeConfiguration
165+
from ._job.finetuning.custom_model_finetuning_job import CustomModelFineTuningJob
139166
from ._job.input_port import InputPort
140167
from ._job.job import Job
141168
from ._job.job_limits import CommandJobLimits
169+
from ._job.job_resources import JobResources
142170
from ._job.job_resource_configuration import JobResourceConfiguration
143-
from ._job.job_service import JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService
171+
from ._job.job_service import (
172+
JobService,
173+
JupyterLabJobService,
174+
SshJobService,
175+
TensorBoardJobService,
176+
VsCodeJobService,
177+
)
144178
from ._job.parallel.parallel_task import ParallelTask
145179
from ._job.parallel.retry_settings import RetrySettings
146180
from ._job.parameterized_command import ParameterizedCommand
@@ -156,7 +190,12 @@
156190
from ._monitoring.alert_notification import AlertNotification
157191
from ._monitoring.compute import ServerlessSparkCompute
158192
from ._monitoring.definition import MonitorDefinition
159-
from ._monitoring.input_data import FixedInputData, MonitorInputData, StaticInputData, TrailingInputData
193+
from ._monitoring.input_data import (
194+
FixedInputData,
195+
MonitorInputData,
196+
StaticInputData,
197+
TrailingInputData,
198+
)
160199
from ._monitoring.schedule import MonitorSchedule
161200
from ._monitoring.signals import (
162201
BaselineDataRange,
@@ -244,7 +283,11 @@
244283
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
245284
from ._workspace.serverless_compute import ServerlessComputeSettings
246285
from ._workspace.workspace import Workspace
247-
from ._workspace.workspace_keys import ContainerRegistryCredential, NotebookAccessKeys, WorkspaceKeys
286+
from ._workspace.workspace_keys import (
287+
ContainerRegistryCredential,
288+
NotebookAccessKeys,
289+
WorkspaceKeys,
290+
)
248291

249292
__all__ = [
250293
"Resource",
@@ -258,8 +301,10 @@
258301
"SparkJobEntryType",
259302
"CommandJobLimits",
260303
"ComputeConfiguration",
304+
"CustomModelFineTuningJob",
261305
"CreatedByType",
262306
"ResourceConfiguration",
307+
"JobResources",
263308
"JobResourceConfiguration",
264309
"QueueSettings",
265310
"JobService",

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/_input_output_helpers.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -205,8 +205,7 @@ def validate_pipeline_input_key_characters(key: str) -> None:
205205
# so a valid pipeline key is: ^{single_key}([.]{single_key})*$
206206
if re.match(IOConstants.VALID_KEY_PATTERN, key) is None:
207207
msg = (
208-
"Pipeline input key name {} must be composed letters, numbers, and underscores with optional "
209-
"split by dots."
208+
"Pipeline input key name {} must be composed letters, numbers, and underscores with optional split by dots."
210209
)
211210
raise ValidationException(
212211
message=msg.format(key),
@@ -262,7 +261,6 @@ def to_rest_dataset_literal_inputs(
262261
uri=input_value.path,
263262
mode=(INPUT_MOUNT_MAPPING_TO_REST[input_value.mode.lower()] if input_value.mode else None),
264263
)
265-
266264
else:
267265
msg = f"Job input type {input_value.type} is not supported as job input."
268266
raise ValidationException(
@@ -415,7 +413,7 @@ def from_rest_data_outputs(outputs: Dict[str, RestJobOutput]) -> Dict[str, Outpu
415413
path_on_compute=sourcePathOnCompute,
416414
description=output_value.description,
417415
name=output_value.asset_name,
418-
version=output_value.asset_version,
416+
version=(output_value.asset_version if hasattr(output_value, "asset_version") else None),
419417
)
420418
else:
421419
msg = "unsupported JobOutput type: {}".format(output_value.job_output_type)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/finetuning/custom_model_finetuning_job.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
from typing import Any, Dict
88

9-
from azure.ai.ml._restclient.v2024_01_01_preview.models import (
9+
from azure.ai.ml._restclient.v2024_10_01_preview.models import (
1010
ModelProvider as RestModelProvider,
1111
CustomModelFineTuning as RestCustomModelFineTuningVertical,
1212
FineTuningJob as RestFineTuningJob,
@@ -16,6 +16,8 @@
1616
from_rest_data_outputs,
1717
to_rest_data_outputs,
1818
)
19+
from azure.ai.ml.entities._job.job_resources import JobResources
20+
from azure.ai.ml.entities._job.queue_settings import QueueSettings
1921
from azure.ai.ml.entities._inputs_outputs import Input
2022
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
2123
from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical
@@ -90,9 +92,14 @@ def _to_rest_object(self) -> "RestFineTuningJob":
9092
services=self.services,
9193
tags=self.tags,
9294
properties=self.properties,
95+
compute_id=self.compute,
9396
fine_tuning_details=custom_finetuning_vertical,
9497
outputs=to_rest_data_outputs(self.outputs),
9598
)
99+
if self.resources:
100+
finetuning_job.resources = self.resources._to_rest_object()
101+
if self.queue_settings:
102+
finetuning_job.queue_settings = self.queue_settings._to_rest_object()
96103

97104
result = RestJobBase(properties=finetuning_job)
98105
result.name = self.name
@@ -168,9 +175,15 @@ def _from_rest_object(cls, obj: RestJobBase) -> "CustomModelFineTuningJob":
168175
"status": properties.status,
169176
"creation_context": obj.system_data,
170177
"display_name": properties.display_name,
178+
"compute": properties.compute_id,
171179
"outputs": from_rest_data_outputs(properties.outputs),
172180
}
173181

182+
if properties.resources:
183+
job_args_dict["resources"] = JobResources._from_rest_object(properties.resources)
184+
if properties.queue_settings:
185+
job_args_dict["queue_settings"] = QueueSettings._from_rest_object(properties.queue_settings)
186+
174187
custom_model_finetuning_job = cls(
175188
task=finetuning_details.task_type,
176189
model=finetuning_details.model,

0 commit comments

Comments (0)