@@ -1461,14 +1461,16 @@ mod test_hostname_override_source {
14611461
14621462// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=
14631463
1464- /// NvidiaRuntimeSettings contains the container runtime settings for Nvidia gpu.
1464+ /// NvidiaDevicePluginSettings contains the device sharing and partitioning related settings for Nvidia gpu.
14651465#[ model( impl_default = true ) ]
14661466pub struct NvidiaDevicePluginSettings {
14671467 pass_device_specs : bool ,
14681468 device_id_strategy : NvidiaDeviceIdStrategy ,
14691469 device_list_strategy : NvidiaDeviceListStrategy ,
14701470 device_sharing_strategy : NvidiaDeviceSharingStrategy ,
14711471 time_slicing : NvidiaTimeSlicingSettings ,
1472+ device_partitioning_strategy : NvidiaDevicePartitioningStrategy ,
1473+ mig : NvidiaMigSettings ,
14721474}
14731475
14741476#[ derive( Debug , Clone , Eq , PartialEq , Hash , Serialize , Deserialize ) ]
@@ -1499,10 +1501,123 @@ pub struct NvidiaTimeSlicingSettings {
14991501 fail_requests_greater_than_one : bool ,
15001502}
15011503
1504+ #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize , Default ) ]
1505+ #[ serde( rename_all = "lowercase" ) ]
1506+ pub enum NvidiaDevicePartitioningStrategy {
1507+ #[ default]
1508+ None ,
1509+ MIG ,
1510+ }
1511+
1512+ #[ model( impl_default = true ) ]
1513+ pub struct NvidiaMigSettings {
1514+ profile : HashMap < NvidiaGpuModel , MigProfile > ,
1515+ }
1516+
/// An Nvidia GPU model name such as `a100.40gb`.
///
/// The name is kept verbatim in `inner`; validation happens in the `TryFrom<&str>` conversion.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NvidiaGpuModel {
    inner: String,
}
1521+
1522+ lazy_static ! {
1523+ pub ( crate ) static ref NVIDIAGPU_NAME : Regex = Regex :: new( r"^([a-z])(\d+)\.(\d+)gb$" ) . unwrap( ) ;
1524+ }
1525+
1526+ impl TryFrom < & str > for NvidiaGpuModel {
1527+ type Error = error:: Error ;
1528+
1529+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1530+ ensure ! (
1531+ NVIDIAGPU_NAME . is_match( input) ,
1532+ error:: PatternSnafu {
1533+ thing: "NVIDIA GPU Model" ,
1534+ pattern: NVIDIAGPU_NAME . clone( ) ,
1535+ input
1536+ }
1537+ ) ;
1538+
1539+ Ok ( NvidiaGpuModel {
1540+ inner : input. to_string ( ) ,
1541+ } )
1542+ }
1543+ }
1544+
1545+ string_impls_for ! ( NvidiaGpuModel , "NvidiaGpuModel" ) ;
1546+
/// A MIG profile, either in `<n>g.<m>gb` form (e.g. `1g.5gb`) or as a bare slice count.
///
/// The text is kept verbatim in `inner`; validation happens in the `TryFrom<&str>` conversion.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MigProfile {
    inner: String,
}
1551+
1552+ lazy_static ! {
1553+ pub ( crate ) static ref MIGPROFILE_NAME : Regex = Regex :: new( r"^[0-9]g\.\d+gb$" ) . unwrap( ) ;
1554+ }
1555+
1556+ impl TryFrom < & str > for MigProfile {
1557+ type Error = error:: Error ;
1558+
1559+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1560+ let slice_format = matches ! ( input, "1" | "2" | "3" | "4" | "7" ) ;
1561+
1562+ ensure ! (
1563+ slice_format | MIGPROFILE_NAME . is_match( input) ,
1564+ error:: PatternSnafu {
1565+ thing: "MIG Profile" ,
1566+ pattern: MIGPROFILE_NAME . clone( ) ,
1567+ input
1568+ }
1569+ ) ;
1570+
1571+ Ok ( MigProfile {
1572+ inner : input. to_string ( ) ,
1573+ } )
1574+ }
1575+ }
1576+
1577+ string_impls_for ! ( MigProfile , "MigProfile" ) ;
1578+
15021579#[ cfg( test) ]
1503- mod tests {
1580+ mod test_nvidia_device_plugins {
15041581 use super :: * ;
15051582
/// Every well-formed GPU model name must convert successfully.
#[test]
fn valid_gpu_model() {
    let accepted = ["a100.40gb", "a100.80gb", "h100.80gb", "h100.141gb"];
    accepted
        .iter()
        .for_each(|name| assert!(NvidiaGpuModel::try_from(*name).is_ok()));
}
1589+
/// Malformed GPU model names must be rejected.
#[test]
fn invalid_gpu_model() {
    let rejected = ["invalid", "1000", "A100.40GB", "a100.40"];
    for candidate in rejected.iter() {
        assert!(NvidiaGpuModel::try_from(*candidate).is_err());
    }
}
1597+
/// Both named profiles and bare slice counts must convert successfully.
#[test]
fn valid_mig_profile() {
    let named = [
        "1g.5gb", "2g.10gb", "3g.20gb", "7g.40gb", "1g.10gb", "1g.20gb", "2g.20gb", "3g.40gb",
        "7g.80gb", "1g.18gb", "1g.35gb", "2g.35gb", "3g.71gb", "7g.141gb",
    ];
    let slice_counts = ["1", "2", "3", "4", "7"];

    // Same inputs, same order as a single flat list: named profiles first, then slice counts.
    named
        .iter()
        .chain(slice_counts.iter())
        .for_each(|profile| assert!(MigProfile::try_from(*profile).is_ok()));
}
1608+
/// Malformed MIG profiles must be rejected.
#[test]
fn invalid_mig_profile() {
    let rejected = [
        "invalid",
        "1000",
        "5",
        "10g.100GB",
        "1g.10GB",
        "1g10gb",
        "g.10gb",
        "1g.gb",
    ];
    for candidate in rejected.iter() {
        assert!(MigProfile::try_from(*candidate).is_err());
    }
}
1620+
15061621 #[ test]
15071622 fn test_serde_nvidia_device_plugins ( ) {
15081623 let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar"}"# ;
@@ -1515,7 +1630,9 @@ mod tests {
15151630 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15161631 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15171632 device_sharing_strategy: None ,
1518- time_slicing: None
1633+ time_slicing: None ,
1634+ device_partitioning_strategy: None ,
1635+ mig: None
15191636 }
15201637 ) ;
15211638 let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
@@ -1534,7 +1651,9 @@ mod tests {
15341651 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15351652 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15361653 device_sharing_strategy: Some ( NvidiaDeviceSharingStrategy :: TimeSlicing ) ,
1537- time_slicing: None
1654+ time_slicing: None ,
1655+ device_partitioning_strategy: None ,
1656+ mig: None
15381657 }
15391658 ) ;
15401659
@@ -1548,4 +1667,53 @@ mod tests {
15481667 let result: Result < NvidiaDevicePluginSettings , _ > = serde_json:: from_str ( test_json) ;
15491668 assert ! ( result. is_err( ) , "The JSON should not be parsed successfully as it contains an invalid value for 'replicas'." ) ;
15501669 }
1670+
/// Round-trips settings that select the `mig` partitioning strategy without a `mig` section.
#[test]
fn test_serde_nvidia_device_plugins_with_mig() {
    let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar","device-partitioning-strategy":"mig"}"#;

    let expected = NvidiaDevicePluginSettings {
        pass_device_specs: Some(false),
        device_id_strategy: Some(NvidiaDeviceIdStrategy::Uuid),
        device_list_strategy: Some(NvidiaDeviceListStrategy::Envvar),
        device_sharing_strategy: None,
        time_slicing: None,
        device_partitioning_strategy: Some(NvidiaDevicePartitioningStrategy::MIG),
        mig: None,
    };

    // Deserialization must yield exactly the expected settings.
    let parsed: NvidiaDevicePluginSettings = serde_json::from_str(test_json).unwrap();
    assert_eq!(parsed, expected);

    // Serializing again must reproduce the input JSON byte for byte.
    let reserialized = serde_json::to_string(&parsed).unwrap();
    assert_eq!(reserialized, test_json);
}
1692+
/// Round-trips settings that carry a per-GPU-model MIG profile map.
#[test]
fn test_serde_nvidia_device_plugins_with_mig_profile() {
    let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar","device-partitioning-strategy":"mig","mig":{"profile":{"a100.40gb":"1g.5gb"}}}"#;

    // A single entry keeps the serialized map order deterministic.
    let mut profiles = HashMap::new();
    profiles.insert(
        NvidiaGpuModel::try_from("a100.40gb").unwrap(),
        MigProfile::try_from("1g.5gb").unwrap(),
    );

    let expected = NvidiaDevicePluginSettings {
        pass_device_specs: Some(false),
        device_id_strategy: Some(NvidiaDeviceIdStrategy::Uuid),
        device_list_strategy: Some(NvidiaDeviceListStrategy::Envvar),
        device_sharing_strategy: None,
        time_slicing: None,
        device_partitioning_strategy: Some(NvidiaDevicePartitioningStrategy::MIG),
        mig: Some(NvidiaMigSettings {
            profile: Some(profiles),
        }),
    };

    // Deserialization must yield exactly the expected settings.
    let parsed: NvidiaDevicePluginSettings = serde_json::from_str(test_json).unwrap();
    assert_eq!(parsed, expected);

    // Serializing again must reproduce the input JSON byte for byte.
    let reserialized = serde_json::to_string(&parsed).unwrap();
    assert_eq!(reserialized, test_json);
}
15511719}
0 commit comments