@@ -117,29 +117,37 @@ def _get_instance_type_parameters(): # noqa: C901
117117 for page in paginator .paginate (LocationType = "availability-zone" ):
118118 for instance_type in page ["InstanceTypeOfferings" ]:
119119 # Check if instance type ends with '.xlarge'
120- if instance_type ["InstanceType" ].endswith (".xlarge" ) and not any (
121- instance_type [ "InstanceType" ]. startswith ( prefix ) for prefix in excluded_instance_type_prefixes
120+ if instance_type ["InstanceType" ].endswith (".xlarge" ) and _is_current_instance_type_generation (
121+ excluded_instance_type_prefixes , instance_type
122122 ):
123123 xlarge_instances .append (instance_type ["InstanceType" ])
124- if instance_type_availability_zones .get (instance_type ["InstanceType" ]):
125- instance_type_availability_zones [instance_type ["InstanceType" ]].append (
126- instance_type ["Location" ]
127- )
128- else :
129- instance_type_availability_zones [instance_type ["InstanceType" ]] = [
130- instance_type ["Location" ]
131- ]
124+ if instance_type_availability_zones .get (instance_type ["InstanceType" ]):
125+ instance_type_availability_zones [instance_type ["InstanceType" ]].append (
126+ instance_type ["Location" ]
127+ )
128+ else :
129+ instance_type_availability_zones [instance_type ["InstanceType" ]] = [instance_type ["Location" ]]
132130
133131 xlarge_instances = list (set (xlarge_instances )) # Remove redundancy.
134132 gpu_instances = []
135133 paginator = ec2_client .get_paginator ("describe_instance_types" )
136134 for page in paginator .paginate (InstanceTypes = xlarge_instances ):
137135 for instance_type in page ["InstanceTypes" ]:
138- if instance_type .get ("GpuInfo" ):
139- if (
140- instance_type .get ("GpuInfo" ).get ("Gpus" )
141- and instance_type .get ("GpuInfo" ).get ("Gpus" )[0 ].get ("Manufacturer" ) == "NVIDIA"
142- ):
136+ if _is_nvidia_gpu_instance_type (instance_type ):
137+ gpu_instances .append (instance_type ["InstanceType" ])
138+
139+ for page in paginator .paginate ():
140+ for instance_type in page ["InstanceTypes" ]:
141+ if (
142+ _is_nvidia_gpu_instance_type (instance_type )
143+ and instance_type .get ("GpuInfo" ).get ("Gpus" )[0 ].get ("Count" ) >= 4
144+ and _is_current_instance_type_generation (excluded_instance_type_prefixes , instance_type )
145+ ):
146+ # Find instance types with 4 or more GPUs. Number of GPUs can change test behavior.
147+ # For example, it takes longer for DCGM health check to diagnose multiple GPUs.
148+ instance_size = instance_type ["InstanceType" ].split ("." )[1 ][: - len ("xlarge" )]
149+ if instance_size and int (instance_size ) < 20 :
150+ # Avoid using very expensive instance types
143151 gpu_instances .append (instance_type ["InstanceType" ])
144152
145153 xlarge_instances .sort ()
@@ -154,7 +162,7 @@ def _get_instance_type_parameters(): # noqa: C901
154162 )
155163 for index in range (len (gpu_instances )):
156164 instance_type = gpu_instances [(today_number + index ) % len (gpu_instances )]
157- result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = instance_type [: - len ( ".xlarge" )]
165+ result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = instance_type
158166 availability_zones = instance_type_availability_zones [instance_type ]
159167 result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } _AZ" ] = (
160168 availability_zones [0 ] if len (availability_zones ) <= 2 else region
@@ -165,11 +173,23 @@ def _get_instance_type_parameters(): # noqa: C901
165173 result [f"{ region_jinja } _INSTANCE_TYPE_{ index } " ] = "c5"
166174 result [f"{ region_jinja } _INSTANCE_TYPE_{ index } _AZ" ] = region
167175 for index in range (10 ):
168- result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = "g4dn"
176+ result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = "g4dn.xlarge "
169177 result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } _AZ" ] = region
170178 return result
171179
172180
def _is_nvidia_gpu_instance_type(instance_type):
    """
    Tell whether an instance type description lists an NVIDIA GPU.

    :param instance_type: mapping describing a single instance type. It is read through the
        "GpuInfo" -> "Gpus" -> first entry -> "Manufacturer" keys.
    :return: True when the first listed GPU has "Manufacturer" equal to "NVIDIA"; False otherwise,
        including when "GpuInfo" or "Gpus" is missing or empty.
    """
    # Look the GPU list up once; a missing or empty "GpuInfo" is treated like "no GPUs".
    gpus = (instance_type.get("GpuInfo") or {}).get("Gpus")
    # Only the first GPU entry is inspected. bool() keeps the result a strict boolean
    # rather than leaking None / [] / {} to callers.
    return bool(gpus) and gpus[0].get("Manufacturer") == "NVIDIA"
187+
188+
def _is_current_instance_type_generation(excluded_instance_type_prefixes, instance_type):
    """
    Tell whether an instance type name starts with none of the excluded prefixes.

    :param excluded_instance_type_prefixes: iterable of instance type name prefixes to reject.
    :param instance_type: mapping whose "InstanceType" entry holds the instance type name.
    :return: True when the name matches no excluded prefix (always True for an empty
        prefix collection), False as soon as one prefix matches.
    """
    # str.startswith accepts a tuple of candidates, so one call replaces the per-prefix loop.
    return not instance_type["InstanceType"].startswith(tuple(excluded_instance_type_prefixes))
191+
192+
173193def _get_available_amis_oss (architecture , args = None , config = None ):
174194 """
175195 Gets available AMIs for given architecture from input.
@@ -306,10 +326,16 @@ def _check_or_create_capacity_reservations(config_file, os_parameters, instance_
306326
307327def _resolve_instance_type_and_os (instance_type , instance_type_parameters , os , os_parameters ):
308328 if "INSTANCE_TYPE" in instance_type :
329+ # The value of the Jinja INSTANCE_TYPE variable can contain a size or not, e.g. trn1.32xlarge vs trn1.
330+ # When Jinja name is like INSTANCE_TYPE_0_xlarge, the value doesn't contain size
331+ # When Jinja name is like INSTANCE_TYPE_0, the value contains size.
332+ # In other words, the size should appear once either in name or value. The code below handles this logic.
309333 instance_type_size = instance_type .split ("_" )[- 1 ]
310- instance_type = (
311- instance_type_parameters .get (instance_type [: - len (instance_type_size ) - 1 ]) + "." + instance_type_size
312- )
334+ instance_type_family = instance_type_parameters .get (instance_type [: - len (instance_type_size ) - 1 ])
335+ if instance_type_family :
336+ instance_type = instance_type_family + "." + instance_type_size
337+ else :
338+ instance_type = instance_type_parameters .get (instance_type )
313339 else :
314340 instance_type = instance_type .replace ("_" , "." )
315341 os_platform = "Linux/UNIX"
0 commit comments