55# pylint: disable=protected-access
66
77import logging
8+ import warnings
89from os import PathLike
910from pathlib import Path
1011from typing import Any , Dict , Optional , Union
2728
2829from .code_configuration import CodeConfiguration
2930from .deployment import Deployment
31+ from .model_batch_deployment_settings import ModelBatchDeploymentSettings as BatchDeploymentSettings
3032
3133module_logger = logging .getLogger (__name__ )
3234
35+ SETTINGS_ATTRIBUTES = [
36+ "output_action" ,
37+ "output_file_name" ,
38+ "error_threshold" ,
39+ "retry_settings" ,
40+ "logging_level" ,
41+ "mini_batch_size" ,
42+ "max_concurrency_per_instance" ,
43+ "environment_variables" ,
44+ ]
45+
3346
3447class BatchDeployment (Deployment ):
3548 """Batch endpoint deployment entity.
3649
50+ **Warning** This class should not be used directly.
51+ Please use one of the child implementations, :class:`~azure.ai.ml.entities.ModelBatchDeployment` or
52+ :class:`azure.ai.ml.entities.PipelineComponentBatchDeployment`.
53+
3754 :param name: the name of the batch deployment
3855 :type name: str
3956 :param description: Description of the resource.
@@ -112,34 +129,61 @@ def __init__(
112129 instance_count : Optional [int ] = None , # promoted property from resources.instance_count
113130 ** kwargs : Any ,
114131 ) -> None :
132+ _type = kwargs .pop ("_type" , None )
133+
134+ # Suppresses deprecation warning when object is created from REST responses
135+ # This is needed to avoid false deprecation warning on model batch deployment
136+ if _type is None and not kwargs .pop ("_from_rest" , False ):
137+ warnings .warn (
138+ "This class is intended as a base class and it's direct usage is deprecated. "
139+ "Use one of the concrete implementations instead:\n "
140+ "* ModelBatchDeployment - For model-based batch deployments\n "
141+ "* PipelineComponentBatchDeployment - For pipeline component-based batch deployments"
142+ )
115143 self ._provisioning_state : Optional [str ] = kwargs .pop ("provisioning_state" , None )
116144
145+ settings = kwargs .pop ("settings" , None )
117146 super (BatchDeployment , self ).__init__ (
118147 name = name ,
148+ type = _type ,
119149 endpoint_name = endpoint_name ,
120150 properties = properties ,
121151 tags = tags ,
122152 description = description ,
123153 model = model ,
124154 code_configuration = code_configuration ,
125155 environment = environment ,
126- environment_variables = environment_variables ,
156+ environment_variables = environment_variables , # needed, otherwise Deployment.__init__() will set it to {}
127157 code_path = code_path ,
128158 scoring_script = scoring_script ,
129159 ** kwargs ,
130160 )
131161
132162 self .compute = compute
133163 self .resources = resources
134- self .output_action = output_action
135- self .output_file_name = output_file_name
136- self .error_threshold = error_threshold
137- self .retry_settings = retry_settings
138- self .logging_level = logging_level
139- self .mini_batch_size = mini_batch_size
140- self .max_concurrency_per_instance = max_concurrency_per_instance
141-
142- if self .resources and instance_count :
164+
165+ self ._settings = (
166+ settings
167+ if settings
168+ else BatchDeploymentSettings (
169+ mini_batch_size = mini_batch_size ,
170+ instance_count = instance_count ,
171+ max_concurrency_per_instance = max_concurrency_per_instance ,
172+ output_action = output_action ,
173+ output_file_name = output_file_name ,
174+ retry_settings = retry_settings ,
175+ environment_variables = environment_variables ,
176+ error_threshold = error_threshold ,
177+ logging_level = logging_level ,
178+ )
179+ )
180+
181+ self ._setup_instance_count ()
182+
183+ def _setup_instance_count (
184+ self ,
185+ ) -> None : # No need to check instance_count here as it's already set in self._settings during initialization
186+ if self .resources and self ._settings .instance_count :
143187 msg = "Can't set instance_count when resources is provided."
144188 raise ValidationException (
145189 message = msg ,
@@ -149,8 +193,26 @@ def __init__(
149193 error_type = ValidationErrorType .INVALID_VALUE ,
150194 )
151195
152- if not self .resources and instance_count :
153- self .resources = ResourceConfiguration (instance_count = instance_count )
196+ if not self .resources and self ._settings .instance_count :
197+ self .resources = ResourceConfiguration (instance_count = self ._settings .instance_count )
198+
199+ def __getattr__ (self , name : str ) -> Optional [Any ]:
200+ # Support backwards compatibility with old BatchDeployment properties.
201+ if name in SETTINGS_ATTRIBUTES :
202+ try :
203+ return getattr (self ._settings , name )
204+ except AttributeError :
205+ pass
206+ return super ().__getattribute__ (name )
207+
208+ def __setattr__ (self , name , value ):
209+ # Support backwards compatibility with old BatchDeployment properties.
210+ if name in SETTINGS_ATTRIBUTES :
211+ try :
212+ setattr (self ._settings , name , value )
213+ except AttributeError :
214+ pass
215+ super ().__setattr__ (name , value )
154216
155217 @property
156218 def instance_count (self ) -> Optional [int ]:
@@ -195,7 +257,7 @@ def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: Any) -> s
195257 return output_switcher .get (yaml_output_action , yaml_output_action )
196258
197259 # pylint: disable=arguments-differ
198- def _to_rest_object (self , location : str ) -> BatchDeploymentData : # type: ignore
260+ def _to_rest_object (self , location : str ) -> BatchDeploymentData : # type: ignore[override]
199261 self ._validate ()
200262 code_config = (
201263 RestCodeConfiguration (
@@ -209,42 +271,28 @@ def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
209271 environment = self .environment
210272
211273 batch_deployment : RestBatchDeployment = None
212- if isinstance (self .output_action , str ):
213- batch_deployment = RestBatchDeployment (
214- compute = self .compute ,
215- description = self .description ,
216- resources = self .resources ._to_rest_object () if self .resources else None ,
217- code_configuration = code_config ,
218- environment_id = environment ,
219- model = model ,
220- output_file_name = self .output_file_name ,
221- output_action = BatchDeployment ._yaml_output_action_to_rest_output_action (self .output_action ),
222- error_threshold = self .error_threshold ,
223- retry_settings = self .retry_settings ._to_rest_object () if self .retry_settings else None ,
224- logging_level = self .logging_level ,
225- mini_batch_size = self .mini_batch_size ,
226- max_concurrency_per_instance = self .max_concurrency_per_instance ,
227- environment_variables = self .environment_variables ,
228- properties = self .properties ,
229- )
230- else :
231- batch_deployment = RestBatchDeployment (
232- compute = self .compute ,
233- description = self .description ,
234- resources = self .resources ._to_rest_object () if self .resources else None ,
235- code_configuration = code_config ,
236- environment_id = environment ,
237- model = model ,
238- output_file_name = self .output_file_name ,
239- output_action = None ,
240- error_threshold = self .error_threshold ,
241- retry_settings = self .retry_settings ._to_rest_object () if self .retry_settings else None ,
242- logging_level = self .logging_level ,
243- mini_batch_size = self .mini_batch_size ,
244- max_concurrency_per_instance = self .max_concurrency_per_instance ,
245- environment_variables = self .environment_variables ,
246- properties = self .properties ,
247- )
274+ # Create base RestBatchDeployment object with common properties
275+ batch_deployment = RestBatchDeployment (
276+ compute = self .compute ,
277+ description = self .description ,
278+ resources = self .resources ._to_rest_object () if self .resources else None ,
279+ code_configuration = code_config ,
280+ environment_id = environment ,
281+ model = model ,
282+ output_file_name = self .output_file_name ,
283+ output_action = (
284+ BatchDeployment ._yaml_output_action_to_rest_output_action (self .output_action )
285+ if isinstance (self .output_action , str )
286+ else None
287+ ),
288+ error_threshold = self .error_threshold ,
289+ retry_settings = self .retry_settings ._to_rest_object () if self .retry_settings else None ,
290+ logging_level = self .logging_level ,
291+ mini_batch_size = self .mini_batch_size ,
292+ max_concurrency_per_instance = self .max_concurrency_per_instance ,
293+ environment_variables = self .environment_variables ,
294+ properties = self .properties ,
295+ )
248296
249297 return BatchDeploymentData (location = location , properties = batch_deployment , tags = self .tags )
250298
@@ -306,6 +354,7 @@ def _from_rest_object( # pylint: disable=arguments-renamed
306354 properties = properties ,
307355 creation_context = SystemData ._from_rest_object (deployment .system_data ),
308356 provisioning_state = deployment .properties .provisioning_state ,
357+ _from_rest = True ,
309358 )
310359
311360 return deployment
0 commit comments