@@ -373,32 +373,33 @@ def build_dict(**kwargs):
373373 return {k : v for k , v in kwargs .items () if v is not None }
374374
375375 # Build resources
376- if self .instance_type is None :
377- requests_value = limits_value = {"nvidia.com/gpu" : "0" }
376+ if self .accelerator_partition_type :
377+ partition_resource_key = f"nvidia.com/{ self .accelerator_partition_type } "
378+ requests_value = build_dict (
379+ ** {partition_resource_key : str (self .accelerator_partition_count )} if self .accelerator_partition_count else {},
380+ vcpu = str (self .vcpu ) if self .vcpu else None ,
381+ memory = str (self .memory ) if self .memory else None ,
382+ ** {"vpc.amazonaws.com/efa" : "1" } if self .instance_type and "p4d" in self .instance_type else {}
383+ )
384+ limits_value = build_dict (
385+ ** {partition_resource_key : str (self .accelerator_partition_limit )} if self .accelerator_partition_limit else {},
386+ vcpu = str (self .vcpu_limit ) if self .vcpu_limit else None ,
387+ memory = str (self .memory_limit ) if self .memory_limit else None ,
388+ ** {"vpc.amazonaws.com/efa" : "1" } if self .instance_type and "p4d" in self .instance_type else {}
389+ )
378390 else :
379- if self .accelerator_partition_type :
380- partition_resource_key = f"nvidia.com/{ self .accelerator_partition_type } "
381- requests_value = build_dict (
382- ** {partition_resource_key : str (self .accelerator_partition_count )} if self .accelerator_partition_count else {},
383- vcpu = str (self .vcpu ) if self .vcpu else None ,
384- memory = str (self .memory ) if self .memory else None
385- )
386- limits_value = build_dict (
387- ** {partition_resource_key : str (self .accelerator_partition_limit )} if self .accelerator_partition_limit else {},
388- vcpu = str (self .vcpu_limit ) if self .vcpu_limit else None ,
389- memory = str (self .memory_limit ) if self .memory_limit else None
390- )
391- else :
392- requests_value = build_dict (
393- accelerators = str (self .accelerators ) if self .accelerators else None ,
394- vcpu = str (self .vcpu ) if self .vcpu else None ,
395- memory = str (self .memory ) if self .memory else None
396- )
397- limits_value = build_dict (
398- accelerators = str (self .accelerators_limit ) if self .accelerators_limit else None ,
399- vcpu = str (self .vcpu_limit ) if self .vcpu_limit else None ,
400- memory = str (self .memory_limit ) if self .memory_limit else None
401- )
391+ requests_value = build_dict (
392+ ** {"nvidia.com/gpu" : str (self .accelerators )} if self .accelerators else {},
393+ vcpu = str (self .vcpu ) if self .vcpu else None ,
394+ memory = str (self .memory ) if self .memory else None ,
395+ ** {"vpc.amazonaws.com/efa" : "1" } if self .instance_type and "p4d" in self .instance_type else {}
396+ )
397+ limits_value = build_dict (
398+ ** {"nvidia.com/gpu" : str (self .accelerators_limit )} if self .accelerators_limit else {},
399+ vcpu = str (self .vcpu_limit ) if self .vcpu_limit else None ,
400+ memory = str (self .memory_limit ) if self .memory_limit else None ,
401+ ** {"vpc.amazonaws.com/efa" : "1" } if self .instance_type and "p4d" in self .instance_type else {}
402+ )
402403
403404 # Build container
404405 container_kwargs = build_dict (
0 commit comments