@@ -1469,6 +1469,8 @@ pub struct NvidiaDevicePluginSettings {
14691469 device_list_strategy : NvidiaDeviceListStrategy ,
14701470 device_sharing_strategy : NvidiaDeviceSharingStrategy ,
14711471 time_slicing : NvidiaTimeSlicingSettings ,
1472+ device_partitioning_strategy : NvidiaDevicePartitioningStrategy ,
1473+ mig : NvidiaMIGSettings ,
14721474}
14731475
14741476#[ derive( Debug , Clone , Eq , PartialEq , Hash , Serialize , Deserialize ) ]
@@ -1515,7 +1517,9 @@ mod tests {
15151517 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15161518 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15171519 device_sharing_strategy: None ,
1518- time_slicing: None
1520+ time_slicing: None ,
1521+ device_partitioning_strategy: None ,
1522+ mig: None
15191523 }
15201524 ) ;
15211525 let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
@@ -1534,7 +1538,9 @@ mod tests {
15341538 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15351539 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15361540 device_sharing_strategy: Some ( NvidiaDeviceSharingStrategy :: TimeSlicing ) ,
1537- time_slicing: None
1541+ time_slicing: None ,
1542+ device_partitioning_strategy: None ,
1543+ mig: None
15381544 }
15391545 ) ;
15401546
@@ -1549,3 +1555,179 @@ mod tests {
15491555 assert ! ( result. is_err( ) , "The JSON should not be parsed successfully as it contains an invalid value for 'replicas'." ) ;
15501556 }
15511557}
1558+
1559+ #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize , Default ) ]
1560+ #[ serde( rename_all = "lowercase" ) ]
1561+ pub enum NvidiaDevicePartitioningStrategy {
1562+ #[ default]
1563+ None ,
1564+ MIG ,
1565+ }
1566+
1567+ #[ cfg( test) ]
1568+ mod test_valid_device_partitioning_strategy {
1569+ use super :: * ;
1570+
1571+ #[ test]
1572+ fn test_serde_nvidia_device_plugins_with_mig ( ) {
1573+ let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar","device-partitioning-strategy":"mig"}"# ;
1574+ let nvidia_device_plugins: NvidiaDevicePluginSettings =
1575+ serde_json:: from_str ( test_json) . unwrap ( ) ;
1576+ assert_eq ! (
1577+ nvidia_device_plugins,
1578+ NvidiaDevicePluginSettings {
1579+ pass_device_specs: Some ( false ) ,
1580+ device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
1581+ device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
1582+ device_sharing_strategy: None ,
1583+ time_slicing: None ,
1584+ device_partitioning_strategy: Some ( NvidiaDevicePartitioningStrategy :: MIG ) ,
1585+ mig: None
1586+ }
1587+ ) ;
1588+
1589+ let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
1590+ assert_eq ! ( results, test_json) ;
1591+ }
1592+ }
1593+
1594+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
1595+ pub struct NvidiaGPUModel {
1596+ inner : String ,
1597+ }
1598+
1599+ lazy_static ! {
1600+ pub ( crate ) static ref NVIDIAGPU_NAME : Regex = Regex :: new( r"^([a-z])(\d+)\.(\d+)gb$" ) . unwrap( ) ;
1601+ }
1602+
1603+ impl TryFrom < & str > for NvidiaGPUModel {
1604+ type Error = error:: Error ;
1605+
1606+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1607+ ensure ! (
1608+ NVIDIAGPU_NAME . is_match( input) ,
1609+ error:: PatternSnafu {
1610+ thing: "NVIDIA GPU Model" ,
1611+ pattern: NVIDIAGPU_NAME . clone( ) ,
1612+ input
1613+ }
1614+ ) ;
1615+
1616+ Ok ( NvidiaGPUModel {
1617+ inner : input. to_string ( ) ,
1618+ } )
1619+ }
1620+ }
1621+
1622+ string_impls_for ! ( NvidiaGPUModel , "NvidiaGPUModel" ) ;
1623+
1624+ #[ cfg( test) ]
1625+ mod test_valid_gpu_model {
1626+ use super :: NvidiaGPUModel ;
1627+ use std:: convert:: TryFrom ;
1628+
1629+ #[ test]
1630+ fn valid_gpu_model ( ) {
1631+ for ok in & [ "a100.40gb" , "a100.80gb" , "h100.80gb" , "h100.141gb" ] {
1632+ assert ! ( NvidiaGPUModel :: try_from( * ok) . is_ok( ) ) ;
1633+ }
1634+ }
1635+
1636+ #[ test]
1637+ fn invalid_gpu_model ( ) {
1638+ assert ! ( NvidiaGPUModel :: try_from( "invalid" ) . is_err( ) ) ;
1639+ assert ! ( NvidiaGPUModel :: try_from( "1000" ) . is_err( ) ) ;
1640+ }
1641+ }
1642+
1643+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
1644+ pub struct MIGProfile {
1645+ inner : String ,
1646+ }
1647+
1648+ lazy_static ! {
1649+ pub ( crate ) static ref MIGPROFILE_NAME : Regex = Regex :: new( r"^[0-9]g\.\d+gb$" ) . unwrap( ) ;
1650+ }
1651+
1652+ impl TryFrom < & str > for MIGProfile {
1653+ type Error = error:: Error ;
1654+
1655+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1656+ let slice_format = matches ! ( input, "1" | "2" | "3" | "4" | "7" ) ;
1657+
1658+ ensure ! (
1659+ slice_format | MIGPROFILE_NAME . is_match( input) ,
1660+ error:: PatternSnafu {
1661+ thing: "MIG Profile" ,
1662+ pattern: MIGPROFILE_NAME . clone( ) ,
1663+ input
1664+ }
1665+ ) ;
1666+
1667+ Ok ( MIGProfile {
1668+ inner : input. to_string ( ) ,
1669+ } )
1670+ }
1671+ }
1672+
1673+ string_impls_for ! ( MIGProfile , "MIGProfile" ) ;
1674+
1675+ #[ cfg( test) ]
1676+ mod test_valid_mig_profile {
1677+ use super :: MIGProfile ;
1678+ use std:: convert:: TryFrom ;
1679+
1680+ #[ test]
1681+ fn valid_mig_profile ( ) {
1682+ for ok in & [
1683+ "1g.5gb" , "2g.10gb" , "3g.20gb" , "7g.40gb" , "1g.10gb" , "1g.20gb" , "2g.20gb" , "3g.40gb" ,
1684+ "7g.80gb" , "1g.18gb" , "1g.35gb" , "2g.35gb" , "3g.71gb" , "7g.141gb" , "1" , "2" , "3" , "4" ,
1685+ "7" ,
1686+ ] {
1687+ assert ! ( MIGProfile :: try_from( * ok) . is_ok( ) ) ;
1688+ }
1689+ }
1690+
1691+ #[ test]
1692+ fn invalid_mig_profile ( ) {
1693+ assert ! ( MIGProfile :: try_from( "invalid" ) . is_err( ) ) ;
1694+ assert ! ( MIGProfile :: try_from( "1000" ) . is_err( ) ) ;
1695+ }
1696+ }
1697+
1698+ #[ model( impl_default = true ) ]
1699+ pub struct NvidiaMIGSettings {
1700+ profile : HashMap < NvidiaGPUModel , MIGProfile > ,
1701+ }
1702+
1703+ #[ cfg( test) ]
1704+ mod test_valid_nvidia_mig_settings {
1705+ use super :: * ;
1706+
1707+ #[ test]
1708+ fn test_serde_nvidia_device_plugins_with_mig_profile ( ) {
1709+ let test_json = r#"{"pass-device-specs":false,"device-id-strategy":"uuid","device-list-strategy":"envvar","device-partitioning-strategy":"mig","mig":{"profile":{"a100.40gb":"1g.5gb"}}}"# ;
1710+ let nvidia_device_plugins: NvidiaDevicePluginSettings =
1711+ serde_json:: from_str ( test_json) . unwrap ( ) ;
1712+ assert_eq ! (
1713+ nvidia_device_plugins,
1714+ NvidiaDevicePluginSettings {
1715+ pass_device_specs: Some ( false ) ,
1716+ device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
1717+ device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
1718+ device_sharing_strategy: None ,
1719+ time_slicing: None ,
1720+ device_partitioning_strategy: Some ( NvidiaDevicePartitioningStrategy :: MIG ) ,
1721+ mig: Some ( NvidiaMIGSettings {
1722+ profile: Some ( HashMap :: from( [ (
1723+ NvidiaGPUModel :: try_from( "a100.40gb" ) . unwrap( ) ,
1724+ MIGProfile :: try_from( "1g.5gb" ) . unwrap( )
1725+ ) ] ) )
1726+ } ) ,
1727+ }
1728+ ) ;
1729+
1730+ let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
1731+ assert_eq ! ( results, test_json) ;
1732+ }
1733+ }
0 commit comments