@@ -1461,14 +1461,16 @@ mod test_hostname_override_source {
14611461
14621462// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=
14631463
1464- /// NvidiaRuntimeSettings contains the container runtime settings for Nvidia gpu.
1464+ /// NvidiaDevicePluginSettings contains the device sharing and partitioning related settings for Nvidia gpu.
14651465#[ model( impl_default = true ) ]
14661466pub struct NvidiaDevicePluginSettings {
14671467 pass_device_specs : bool ,
14681468 device_id_strategy : NvidiaDeviceIdStrategy ,
14691469 device_list_strategy : NvidiaDeviceListStrategy ,
14701470 device_sharing_strategy : NvidiaDeviceSharingStrategy ,
14711471 time_slicing : NvidiaTimeSlicingSettings ,
1472+ device_partitioning_strategy : NvidiaDevicePartitioningStrategy ,
1473+ mig : NvidiaMIGSettings ,
14721474}
14731475
14741476#[ derive( Debug , Clone , Eq , PartialEq , Hash , Serialize , Deserialize ) ]
@@ -1499,10 +1501,115 @@ pub struct NvidiaTimeSlicingSettings {
14991501 fail_requests_greater_than_one : bool ,
15001502}
15011503
1504+ #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize , Default ) ]
1505+ #[ serde( rename_all = "lowercase" ) ]
1506+ pub enum NvidiaDevicePartitioningStrategy {
1507+ #[ default]
1508+ None ,
1509+ MIG ,
1510+ }
1511+
1512+ #[ model( impl_default = true ) ]
1513+ pub struct NvidiaMIGSettings {
1514+ profile : HashMap < NvidiaGpuModel , MigProfile > ,
1515+ }
1516+
1517+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
1518+ pub struct NvidiaGpuModel {
1519+ inner : String ,
1520+ }
1521+
1522+ lazy_static ! {
1523+ pub ( crate ) static ref NVIDIAGPU_NAME : Regex = Regex :: new( r"^([a-z])(\d+)\.(\d+)gb$" ) . unwrap( ) ;
1524+ }
1525+
1526+ impl TryFrom < & str > for NvidiaGpuModel {
1527+ type Error = error:: Error ;
1528+
1529+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1530+ ensure ! (
1531+ NVIDIAGPU_NAME . is_match( input) ,
1532+ error:: PatternSnafu {
1533+ thing: "NVIDIA GPU Model" ,
1534+ pattern: NVIDIAGPU_NAME . clone( ) ,
1535+ input
1536+ }
1537+ ) ;
1538+
1539+ Ok ( NvidiaGpuModel {
1540+ inner : input. to_string ( ) ,
1541+ } )
1542+ }
1543+ }
1544+
1545+ string_impls_for ! ( NvidiaGpuModel , "NvidiaGpuModel" ) ;
1546+
1547+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
1548+ pub struct MigProfile {
1549+ inner : String ,
1550+ }
1551+
1552+ lazy_static ! {
1553+ pub ( crate ) static ref MIGPROFILE_NAME : Regex = Regex :: new( r"^[0-9]g\.\d+gb$" ) . unwrap( ) ;
1554+ }
1555+
1556+ impl TryFrom < & str > for MigProfile {
1557+ type Error = error:: Error ;
1558+
1559+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1560+ let slice_format = matches ! ( input, "1" | "2" | "3" | "4" | "7" ) ;
1561+
1562+ ensure ! (
1563+ slice_format | MIGPROFILE_NAME . is_match( input) ,
1564+ error:: PatternSnafu {
1565+ thing: "MIG Profile" ,
1566+ pattern: MIGPROFILE_NAME . clone( ) ,
1567+ input
1568+ }
1569+ ) ;
1570+
1571+ Ok ( MigProfile {
1572+ inner : input. to_string ( ) ,
1573+ } )
1574+ }
1575+ }
1576+
1577+ string_impls_for ! ( MigProfile , "MigProfile" ) ;
1578+
15021579#[ cfg( test) ]
1503- mod tests {
1580+ mod test_nvidia_device_plugins {
15041581 use super :: * ;
15051582
1583+ #[ test]
1584+ fn valid_gpu_model ( ) {
1585+ for ok in & [ "a100.40gb" , "a100.80gb" , "h100.80gb" , "h100.141gb" ] {
1586+ assert ! ( NvidiaGpuModel :: try_from( * ok) . is_ok( ) ) ;
1587+ }
1588+ }
1589+
1590+ #[ test]
1591+ fn invalid_gpu_model ( ) {
1592+ assert ! ( NvidiaGpuModel :: try_from( "invalid" ) . is_err( ) ) ;
1593+ assert ! ( NvidiaGpuModel :: try_from( "1000" ) . is_err( ) ) ;
1594+ }
1595+
1596+ #[ test]
1597+ fn valid_mig_profile ( ) {
1598+ for ok in & [
1599+ "1g.5gb" , "2g.10gb" , "3g.20gb" , "7g.40gb" , "1g.10gb" , "1g.20gb" , "2g.20gb" , "3g.40gb" ,
1600+ "7g.80gb" , "1g.18gb" , "1g.35gb" , "2g.35gb" , "3g.71gb" , "7g.141gb" , "1" , "2" , "3" , "4" ,
1601+ "7" ,
1602+ ] {
1603+ assert ! ( MigProfile :: try_from( * ok) . is_ok( ) ) ;
1604+ }
1605+ }
1606+
1607+ #[ test]
1608+ fn invalid_mig_profile ( ) {
1609+ assert ! ( MigProfile :: try_from( "invalid" ) . is_err( ) ) ;
1610+ assert ! ( MigProfile :: try_from( "1000" ) . is_err( ) ) ;
1611+ }
1612+
15061613 #[ test]
15071614 fn test_serde_nvidia_device_plugins ( ) {
15081615 let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar"}"# ;
@@ -1515,7 +1622,9 @@ mod tests {
15151622 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15161623 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15171624 device_sharing_strategy: None ,
1518- time_slicing: None
1625+ time_slicing: None ,
1626+ device_partitioning_strategy: None ,
1627+ mig: None
15191628 }
15201629 ) ;
15211630 let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
@@ -1534,7 +1643,9 @@ mod tests {
15341643 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15351644 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15361645 device_sharing_strategy: Some ( NvidiaDeviceSharingStrategy :: TimeSlicing ) ,
1537- time_slicing: None
1646+ time_slicing: None ,
1647+ device_partitioning_strategy: None ,
1648+ mig: None
15381649 }
15391650 ) ;
15401651
@@ -1548,4 +1659,53 @@ mod tests {
15481659 let result: Result < NvidiaDevicePluginSettings , _ > = serde_json:: from_str ( test_json) ;
15491660 assert ! ( result. is_err( ) , "The JSON should not be parsed successfully as it contains an invalid value for 'replicas'." ) ;
15501661 }
1662+
1663+ #[ test]
1664+ fn test_serde_nvidia_device_plugins_with_mig ( ) {
1665+ let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar","device-partitioning-strategy":"mig"}"# ;
1666+ let nvidia_device_plugins: NvidiaDevicePluginSettings =
1667+ serde_json:: from_str ( test_json) . unwrap ( ) ;
1668+ assert_eq ! (
1669+ nvidia_device_plugins,
1670+ NvidiaDevicePluginSettings {
1671+ pass_device_specs: Some ( false ) ,
1672+ device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
1673+ device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
1674+ device_sharing_strategy: None ,
1675+ time_slicing: None ,
1676+ device_partitioning_strategy: Some ( NvidiaDevicePartitioningStrategy :: MIG ) ,
1677+ mig: None
1678+ }
1679+ ) ;
1680+
1681+ let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
1682+ assert_eq ! ( results, test_json) ;
1683+ }
1684+
1685+ #[ test]
1686+ fn test_serde_nvidia_device_plugins_with_mig_profile ( ) {
1687+ let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar","device-partitioning-strategy":"mig","mig":{"profile":{"a100.40gb":"1g.5gb"}}}"# ;
1688+ let nvidia_device_plugins: NvidiaDevicePluginSettings =
1689+ serde_json:: from_str ( test_json) . unwrap ( ) ;
1690+ assert_eq ! (
1691+ nvidia_device_plugins,
1692+ NvidiaDevicePluginSettings {
1693+ pass_device_specs: Some ( false ) ,
1694+ device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
1695+ device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
1696+ device_sharing_strategy: None ,
1697+ time_slicing: None ,
1698+ device_partitioning_strategy: Some ( NvidiaDevicePartitioningStrategy :: MIG ) ,
1699+ mig: Some ( NvidiaMIGSettings {
1700+ profile: Some ( HashMap :: from( [ (
1701+ NvidiaGpuModel :: try_from( "a100.40gb" ) . unwrap( ) ,
1702+ MigProfile :: try_from( "1g.5gb" ) . unwrap( )
1703+ ) ] ) )
1704+ } ) ,
1705+ }
1706+ ) ;
1707+
1708+ let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
1709+ assert_eq ! ( results, test_json) ;
1710+ }
15511711}
0 commit comments