@@ -1443,6 +1443,14 @@ def predict(self, items: List[float]) -> List[float]:
14431443 return items
14441444
14451445
1446+ class SubclassImplementsInterface (Interface ):
1447+ def __init__ (self , max_items : int ):
1448+ self .max_items = max_items
1449+
1450+ def predict (self , items : List [float ]) -> List [float ]:
1451+ return items
1452+
1453+
14461454class NotImplementsInterface1 :
14471455 def predict (self , items : str ) -> List [float ]:
14481456 return []
@@ -1462,6 +1470,7 @@ def predict(self, items: List[float]) -> None:
14621470 "expected, value" ,
14631471 [
14641472 (True , ImplementsInterface ),
1473+ (True , SubclassImplementsInterface ),
14651474 (False , ImplementsInterface (1 )),
14661475 (False , NotImplementsInterface1 ),
14671476 (False , NotImplementsInterface2 ),
@@ -1488,14 +1497,22 @@ def test_is_instance_or_supports_protocol(expected, value):
14881497
14891498def test_parse_implements_protocol (parser ):
14901499 parser .add_argument ("--cls" , type = Interface )
1491- assert "known subclasses:" not in get_parser_help (parser )
14921500 cfg = parser .parse_args ([f"--cls={ __name__ } .ImplementsInterface" , "--cls.batch_size=5" ])
14931501 assert cfg .cls .class_path == f"{ __name__ } .ImplementsInterface"
14941502 assert cfg .cls .init_args == Namespace (batch_size = 5 )
14951503 init = parser .instantiate_classes (cfg )
14961504 assert isinstance (init .cls , ImplementsInterface )
14971505 assert init .cls .batch_size == 5
14981506 assert init .cls .predict ([1.0 , 2.0 ]) == [1.0 , 2.0 ]
1507+
1508+ help_str = get_parser_help (parser )
1509+ assert "known subclasses:" in help_str
1510+ assert f"{ __name__ } .SubclassImplementsInterface" in help_str
1511+ help_str = get_parse_args_stdout (parser , [f"--cls.help={ __name__ } .SubclassImplementsInterface" ])
1512+ assert "--cls.max_items" in help_str
1513+ with pytest .raises (ArgumentError , match = "not a subclass or implementer of protocol" ):
1514+ parser .parse_args ([f"--cls.help={ __name__ } .NotImplementsInterface1" ])
1515+
14991516 with pytest .raises (ArgumentError , match = "is a protocol" ):
15001517 parser .parse_args ([f"--cls={ __name__ } .Interface" ])
15011518 with pytest .raises (ArgumentError , match = "does not implement protocol" ):
@@ -1551,13 +1568,17 @@ def test_implements_callable_protocol(expected, value):
15511568
15521569def test_parse_implements_callable_protocol (parser ):
15531570 parser .add_argument ("--cls" , type = CallableInterface )
1554- assert "known subclasses:" not in get_parser_help (parser )
15551571 cfg = parser .parse_args ([f"--cls={ __name__ } .ImplementsCallableInterface" , "--cls.batch_size=7" ])
15561572 assert cfg .cls .class_path == f"{ __name__ } .ImplementsCallableInterface"
15571573 assert cfg .cls .init_args == Namespace (batch_size = 7 )
15581574 init = parser .instantiate_classes (cfg )
15591575 assert isinstance (init .cls , ImplementsCallableInterface )
15601576 assert init .cls ([1.0 , 2.0 ]) == [1.0 , 2.0 ]
1577+
1578+ assert "known subclasses:" not in get_parser_help (parser )
1579+ help_str = get_parse_args_stdout (parser , [f"--cls.help={ __name__ } .ImplementsCallableInterface" ])
1580+ assert "--cls.batch_size" in help_str
1581+
15611582 with pytest .raises (ArgumentError , match = "is a protocol" ):
15621583 parser .parse_args ([f"--cls={ __name__ } .CallableInterface" ])
15631584 with pytest .raises (ArgumentError , match = "does not implement protocol" ):
0 commit comments