@@ -1469,6 +1469,8 @@ pub struct NvidiaDevicePluginSettings {
14691469 device_list_strategy : NvidiaDeviceListStrategy ,
14701470 device_sharing_strategy : NvidiaDeviceSharingStrategy ,
14711471 time_slicing : NvidiaTimeSlicingSettings ,
1472+ device_partitioning_strategy : NvidiaDevicePartitioningStrategy ,
1473+ mig : NvidiaMIGSettings ,
14721474}
14731475
14741476#[ derive( Debug , Clone , Eq , PartialEq , Hash , Serialize , Deserialize ) ]
@@ -1515,7 +1517,9 @@ mod tests {
15151517 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15161518 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15171519 device_sharing_strategy: None ,
1518- time_slicing: None
1520+ time_slicing: None ,
1521+ device_partitioning_strategy: None ,
1522+ mig: None
15191523 }
15201524 ) ;
15211525 let results = serde_json:: to_string ( & nvidia_device_plugins) . unwrap ( ) ;
@@ -1534,7 +1538,9 @@ mod tests {
15341538 device_id_strategy: Some ( NvidiaDeviceIdStrategy :: Uuid ) ,
15351539 device_list_strategy: Some ( NvidiaDeviceListStrategy :: Envvar ) ,
15361540 device_sharing_strategy: Some ( NvidiaDeviceSharingStrategy :: TimeSlicing ) ,
1537- time_slicing: None
1541+ time_slicing: None ,
1542+ device_partitioning_strategy: None ,
1543+ mig: None
15381544 }
15391545 ) ;
15401546
@@ -1549,3 +1555,120 @@ mod tests {
15491555 assert ! ( result. is_err( ) , "The JSON should not be parsed successfully as it contains an invalid value for 'replicas'." ) ;
15501556 }
15511557}
1558+
1559+ #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize , Default ) ]
1560+ #[ serde( rename_all = "lowercase" ) ]
1561+ pub enum NvidiaDevicePartitioningStrategy {
1562+ #[ default]
1563+ None ,
1564+ MIG ,
1565+ }
1566+
1567+ #[ model( impl_default = true ) ]
1568+ pub struct NvidiaMIGSettings {
1569+ profile : HashMap < NvidiaGPUModel , MIGProfile > ,
1570+ }
1571+
1572+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
1573+ pub struct NvidiaGPUModel {
1574+ inner : String ,
1575+ }
1576+
1577+ lazy_static ! {
1578+ pub ( crate ) static ref NVIDIAGPU_NAME : Regex = Regex :: new( r"^([a-z])(\d+)\.(\d+)gb$" ) . unwrap( ) ;
1579+ }
1580+
1581+ impl TryFrom < & str > for NvidiaGPUModel {
1582+ type Error = error:: Error ;
1583+
1584+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1585+ ensure ! (
1586+ NVIDIAGPU_NAME . is_match( input) ,
1587+ error:: PatternSnafu {
1588+ thing: "NVIDIA GPU Model" ,
1589+ pattern: NVIDIAGPU_NAME . clone( ) ,
1590+ input
1591+ }
1592+ ) ;
1593+
1594+ Ok ( NvidiaGPUModel {
1595+ inner : input. to_string ( ) ,
1596+ } )
1597+ }
1598+ }
1599+
1600+ string_impls_for ! ( NvidiaGPUModel , "NvidiaGPUModel" ) ;
1601+
1602+ #[ cfg( test) ]
1603+ mod test_valid_gpu_model {
1604+ use super :: NvidiaGPUModel ;
1605+ use std:: convert:: TryFrom ;
1606+
1607+ #[ test]
1608+ fn valid_gpu_model ( ) {
1609+ for ok in & [ "a100.40gb" , "a100.80gb" , "h100.80gb" , "h100.141gb" ] {
1610+ assert ! ( NvidiaGPUModel :: try_from( * ok) . is_ok( ) ) ;
1611+ }
1612+ }
1613+
1614+ #[ test]
1615+ fn invalid_gpu_model ( ) {
1616+ assert ! ( NvidiaGPUModel :: try_from( "invalid" ) . is_err( ) ) ;
1617+ assert ! ( NvidiaGPUModel :: try_from( "1000" ) . is_err( ) ) ;
1618+ }
1619+ }
1620+
1621+ #[ derive( Debug , Clone , Eq , PartialEq , Hash ) ]
1622+ pub struct MIGProfile {
1623+ inner : String ,
1624+ }
1625+
1626+ lazy_static ! {
1627+ pub ( crate ) static ref MIGPROFILE_NAME : Regex = Regex :: new( r"^[0-9]g\.\d+gb$" ) . unwrap( ) ;
1628+ }
1629+
1630+ impl TryFrom < & str > for MIGProfile {
1631+ type Error = error:: Error ;
1632+
1633+ fn try_from ( input : & str ) -> Result < Self , Self :: Error > {
1634+ let slice_format = matches ! ( input, "1" | "2" | "3" | "4" | "7" ) ;
1635+
1636+ ensure ! (
1637+ slice_format | MIGPROFILE_NAME . is_match( input) ,
1638+ error:: PatternSnafu {
1639+ thing: "MIG Profile" ,
1640+ pattern: MIGPROFILE_NAME . clone( ) ,
1641+ input
1642+ }
1643+ ) ;
1644+
1645+ Ok ( MIGProfile {
1646+ inner : input. to_string ( ) ,
1647+ } )
1648+ }
1649+ }
1650+
1651+ string_impls_for ! ( MIGProfile , "MIGProfile" ) ;
1652+
1653+ #[ cfg( test) ]
1654+ mod test_valid_mig_profile {
1655+ use super :: MIGProfile ;
1656+ use std:: convert:: TryFrom ;
1657+
1658+ #[ test]
1659+ fn valid_mig_profile ( ) {
1660+ for ok in & [
1661+ "1g.5gb" , "2g.10gb" , "3g.20gb" , "7g.40gb" , "1g.10gb" , "1g.20gb" , "2g.20gb" , "3g.40gb" ,
1662+ "7g.80gb" , "1g.18gb" , "1g.35gb" , "2g.35gb" , "3g.71gb" , "7g.141gb" , "1" , "2" , "3" , "4" ,
1663+ "7" ,
1664+ ] {
1665+ assert ! ( MIGProfile :: try_from( * ok) . is_ok( ) ) ;
1666+ }
1667+ }
1668+
1669+ #[ test]
1670+ fn invalid_mig_profile ( ) {
1671+ assert ! ( MIGProfile :: try_from( "invalid" ) . is_err( ) ) ;
1672+ assert ! ( MIGProfile :: try_from( "1000" ) . is_err( ) ) ;
1673+ }
1674+ }
0 commit comments