Skip to content

Commit 224469b

Browse files
authored
test: [TRTLLM-4334] Create 1.0 criteria scope from API stability references (NVIDIA#3069)
* committed APIs validation
  Signed-off-by: Enwei Zhu <[email protected]>
* fix
  Signed-off-by: Enwei Zhu <[email protected]>
* clean name
  Signed-off-by: Enwei Zhu <[email protected]>
* separate
  Signed-off-by: Enwei Zhu <[email protected]>
* add TODOs
  Signed-off-by: Enwei Zhu <[email protected]>
* fix naming
  Signed-off-by: Enwei Zhu <[email protected]>
* fix
  Signed-off-by: Enwei Zhu <[email protected]>
---------
Signed-off-by: Enwei Zhu <[email protected]>
1 parent ea3739e commit 224469b

20 files changed

+497
-620
lines changed

tensorrt_llm/_torch/llm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def __init__(self,
1717
skip_tokenizer_init: bool = False,
1818
trust_remote_code: bool = False,
1919
tensor_parallel_size: int = 1,
20-
pipeline_parallel_size: int = 1,
2120
dtype: str = "auto",
2221
revision: Optional[str] = None,
2322
tokenizer_revision: Optional[str] = None,
@@ -26,6 +25,5 @@ def __init__(self,
2625
kwargs_dict = dict(kwargs)
2726
kwargs_dict['backend'] = 'pytorch'
2827
super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
29-
trust_remote_code, tensor_parallel_size,
30-
pipeline_parallel_size, dtype, revision,
31-
tokenizer_revision, **kwargs_dict)
28+
trust_remote_code, tensor_parallel_size, dtype,
29+
revision, tokenizer_revision, **kwargs_dict)

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def __init__(self,
9898
skip_tokenizer_init: bool = False,
9999
trust_remote_code: bool = False,
100100
tensor_parallel_size: int = 1,
101-
pipeline_parallel_size: int = 1,
102101
dtype: str = "auto",
103102
revision: Optional[str] = None,
104103
tokenizer_revision: Optional[str] = None,
@@ -116,7 +115,6 @@ def __init__(self,
116115
skip_tokenizer_init=skip_tokenizer_init,
117116
trust_remote_code=trust_remote_code,
118117
tensor_parallel_size=tensor_parallel_size,
119-
pipeline_parallel_size=pipeline_parallel_size,
120118
dtype=dtype,
121119
revision=revision,
122120
tokenizer_revision=tokenizer_revision,

tensorrt_llm/llmapi/llm_args.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,8 +650,6 @@ def model_name(self) -> Union[str, Path]:
650650
651651
tensor_parallel_size(int): The number of processes for tensor parallelism. Defaults to 1.
652652
653-
pipeline_parallel_size(int): The number of processes for pipeline parallelism. Defaults to 1.
654-
655653
dtype (str): The data type for the model weights and activations. Defaults to "auto".
656654
Can be "float16", "bfloat16", "float32", or "auto". If "auto", the data type will be automatically inferred from the source model.
657655
If the source data type is "float32", it will be converted to "float16".
@@ -662,6 +660,8 @@ def model_name(self) -> Union[str, Path]:
662660
"""
663661

664662
LLMARGS_IMPLICIT_DOCSTRING = """
663+
pipeline_parallel_size(int): The number of processes for pipeline parallelism. Defaults to 1.
664+
665665
context_parallel_size (int): The context parallel size. Defaults to 1.
666666
667667
gpus_per_node (int, optional): The number of GPUs per node. None means automatic configure. Defaults to None.
@@ -769,15 +769,15 @@ class LlmArgs:
769769

770770
tensor_parallel_size: int = 1
771771

772-
pipeline_parallel_size: int = 1
773-
774772
dtype: str = "auto"
775773

776774
revision: Optional[str] = None
777775

778776
tokenizer_revision: Optional[str] = None
779777

780778
# Below are all remaining arguments
779+
pipeline_parallel_size: int = 1
780+
781781
context_parallel_size: int = 1
782782

783783
gpus_per_node: Optional[int] = None

tests/unittest/api_stability/api_stability_core.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# autoflake: skip_file
2+
import copy
23
import inspect
34
import os
45
import pathlib
@@ -28,13 +29,12 @@ def repr_annotation(field_type: type) -> str:
2829

2930
@dataclass(slots=True)
3031
class ParamSnapshot:
31-
name: str
3232
annotation: type
3333
default: Any = None
3434

3535
@classmethod
3636
def from_inspect(cls, param: inspect.Parameter):
37-
return cls(param.name, param.annotation, param.default)
37+
return cls(param.annotation, param.default)
3838

3939
@classmethod
4040
def from_docstring(cls, param: docstring_parser.common.DocstringParam):
@@ -57,7 +57,7 @@ def from_docstring(cls, param: docstring_parser.common.DocstringParam):
5757
except (NameError, SyntaxError):
5858
default = param.default
5959

60-
return cls(param.arg_name, annotation, default)
60+
return cls(annotation, default)
6161

6262
@classmethod
6363
def from_dict(cls, d: dict):
@@ -77,19 +77,17 @@ def to_dict(self):
7777
return d
7878

7979
def assert_equal(self, other: 'ParamSnapshot'):
80-
assert self.name == other.name
8180
assert self.annotation == other.annotation
8281
assert self.default == other.default
8382

8483

8584
@dataclass(slots=True)
8685
class MethodSnapshot:
87-
name: str
8886
parameters: Dict[str, ParamSnapshot]
8987
return_annotation: type
9088

9189
@classmethod
92-
def from_inspect(cls, name: str, method: MethodType):
90+
def from_inspect(cls, method: MethodType):
9391
signature = inspect.signature(method)
9492
parameters = {}
9593
for param_name, param in signature.parameters.items():
@@ -99,10 +97,10 @@ def from_inspect(cls, name: str, method: MethodType):
9997
return_annotation = signature.return_annotation
10098
if isinstance(return_annotation, str):
10199
return_annotation = eval(return_annotation)
102-
return cls(name, parameters, return_annotation)
100+
return cls(parameters, return_annotation)
103101

104102
@classmethod
105-
def from_docstring(cls, name: str, method: MethodType):
103+
def from_docstring(cls, method: MethodType):
106104
doc = docstring_parser.parse(method.__doc__)
107105
parameters = {}
108106
for param in doc.params:
@@ -112,7 +110,7 @@ def from_docstring(cls, name: str, method: MethodType):
112110
return_annotation = None
113111
else:
114112
return_annotation = eval(doc.returns.type_name)
115-
return cls(name, parameters, return_annotation)
113+
return cls(parameters, return_annotation)
116114

117115
@classmethod
118116
def from_dict(cls, d: dict):
@@ -132,13 +130,23 @@ def to_dict(self):
132130
d['return_annotation'] = repr_annotation(d['return_annotation'])
133131
return d
134132

133+
def merge(self, other: 'MethodSnapshot'):
134+
assert self.parameters.keys().isdisjoint(other.parameters.keys())
135+
self.parameters.update(copy.deepcopy(other.parameters))
136+
assert self.return_annotation == other.return_annotation
137+
135138
def assert_equal(self, other: 'MethodSnapshot'):
136-
assert self.name == other.name
137139
assert self.parameters.keys() == other.parameters.keys()
138140
for name, param in self.parameters.items():
139141
param.assert_equal(other.parameters[name])
140142
assert self.return_annotation == other.return_annotation
141143

144+
def assert_containing(self, other: 'MethodSnapshot'):
145+
for name, param in other.parameters.items():
146+
assert name in self.parameters
147+
self.parameters[name].assert_equal(param)
148+
assert self.return_annotation == other.return_annotation
149+
142150

143151
@dataclass(slots=True)
144152
class ClassSnapshot:
@@ -153,16 +161,14 @@ def from_inspect(cls, snapshot_cls: type):
153161
inst, predicate=inspect.ismethod):
154162
if method_name.startswith("_") and method_name != "__init__":
155163
continue
156-
methods[method_name] = MethodSnapshot.from_inspect(
157-
method_name, method)
164+
methods[method_name] = MethodSnapshot.from_inspect(method)
158165
properties = {}
159166
for prop_name, prop in inspect.getmembers(
160167
snapshot_cls, predicate=lambda x: isinstance(x, property)):
161168
if prop_name.startswith("_"):
162169
continue
163170
annotation = inspect.signature(prop.fget).return_annotation
164-
properties[prop_name] = ParamSnapshot(prop_name, annotation,
165-
inspect._empty)
171+
properties[prop_name] = ParamSnapshot(annotation, inspect._empty)
166172
return cls(methods, properties)
167173

168174
@classmethod
@@ -175,10 +181,9 @@ def from_docstring(cls, snapshot_cls: type):
175181
continue
176182
if method_name == "__init__":
177183
methods["__init__"] = MethodSnapshot.from_docstring(
178-
"__init__", snapshot_cls)
184+
snapshot_cls)
179185
else:
180-
methods[method_name] = MethodSnapshot.from_docstring(
181-
method_name, method)
186+
methods[method_name] = MethodSnapshot.from_docstring(method)
182187
properties = {}
183188
doc = docstring_parser.parse(snapshot_cls.__doc__)
184189
for param in doc.params:
@@ -210,6 +215,19 @@ def to_dict(self):
210215
}
211216
return d
212217

218+
def merge(self, other: 'ClassSnapshot'):
219+
for name, method in self.methods.items():
220+
if name in other.methods:
221+
method.merge(other.methods[name])
222+
new_methods = {
223+
name: method
224+
for name, method in other.methods.items()
225+
if name not in self.methods
226+
}
227+
self.methods.update(copy.deepcopy(new_methods))
228+
assert self.properties.keys().isdisjoint(other.properties.keys())
229+
self.properties.update(copy.deepcopy(other.properties))
230+
213231
def assert_equal(self, other: 'ClassSnapshot'):
214232
assert self.methods.keys() == other.methods.keys()
215233
for name, method in self.methods.items():
@@ -218,30 +236,47 @@ def assert_equal(self, other: 'ClassSnapshot'):
218236
for name, prop in self.properties.items():
219237
prop.assert_equal(other.properties[name])
220238

239+
def assert_containing(self, other: 'ClassSnapshot'):
240+
for name, method in other.methods.items():
241+
assert name in self.methods
242+
self.methods[name].assert_containing(method)
243+
for name, prop in other.properties.items():
244+
assert name in self.properties
245+
self.properties[name].assert_equal(prop)
246+
221247

222248
class ApiStabilityTestHarness:
223249
TEST_CLASS = None
250+
REFERENCE_COMMITTED_DIR = f"{os.path.dirname(__file__)}/references_committed"
224251
REFERENCE_DIR = f"{os.path.dirname(__file__)}/references"
225252
REFERENCE_FILE = None
226253

227-
@classmethod
228-
def reference_path(cls):
229-
return f"{cls.REFERENCE_DIR}/{cls.REFERENCE_FILE}"
230-
231254
@classmethod
232255
def setup_class(cls):
233-
with open(cls.reference_path()) as f:
256+
with open(f"{cls.REFERENCE_DIR}/{cls.REFERENCE_FILE}") as f:
234257
cls.reference = ClassSnapshot.from_dict(yaml.safe_load(f))
235-
cls.error_msg = (
236-
f"API stability validation failed. "
237-
f"This is probably because you changed {cls.TEST_CLASS.__name__}'s APIs, please ask for reviews from the code owners."
238-
)
258+
if os.path.exists(
259+
f"{cls.REFERENCE_COMMITTED_DIR}/{cls.REFERENCE_FILE}"):
260+
with open(
261+
f"{cls.REFERENCE_COMMITTED_DIR}/{cls.REFERENCE_FILE}") as f:
262+
cls.reference_committed = ClassSnapshot.from_dict(
263+
yaml.safe_load(f))
264+
cls.reference.merge(cls.reference_committed)
265+
else:
266+
cls.reference_committed = None
267+
cls.error_msg = f"API validation failed because you changed {cls.TEST_CLASS.__name__}'s APIs, please ask for reviews from the code owners."
268+
cls.error_msg_committed = f"API validation failed because you changed {cls.TEST_CLASS.__name__}'s committed APIs, please ask for approval."
239269

240270
def create_snapshot_from_inspect(self):
241271
return ClassSnapshot.from_inspect(self.TEST_CLASS)
242272

243273
def test_signature(self):
244274
snapshot = self.create_snapshot_from_inspect()
275+
if self.reference_committed is not None:
276+
try:
277+
snapshot.assert_containing(self.reference_committed)
278+
except AssertionError as e:
279+
raise AssertionError(self.error_msg_committed) from e
245280
try:
246281
snapshot.assert_equal(self.reference)
247282
except AssertionError as e:
@@ -252,6 +287,11 @@ def create_snapshot_from_docstring(self):
252287

253288
def test_docstring(self):
254289
snapshot = self.create_snapshot_from_docstring()
290+
if self.reference_committed is not None:
291+
try:
292+
snapshot.assert_containing(self.reference_committed)
293+
except AssertionError as e:
294+
raise AssertionError(self.error_msg_committed) from e
255295
try:
256296
snapshot.assert_equal(self.reference)
257297
except AssertionError as e:
Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
11
methods:
22
__call__:
3-
name: __call__
43
parameters:
54
client_ids:
65
annotation: List[Optional[int]]
76
default: inspect._empty
8-
name: client_ids
97
logits:
108
annotation: List[torch.Tensor]
119
default: inspect._empty
12-
name: logits
1310
req_ids:
1411
annotation: List[int]
1512
default: inspect._empty
16-
name: req_ids
1713
stream_ptr:
1814
annotation: int
1915
default: inspect._empty
20-
name: stream_ptr
2116
token_ids:
2217
annotation: List[List[List[int]]]
2318
default: inspect._empty
24-
name: token_ids
2519
return_annotation: None
2620
properties: {}
Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,35 @@
11
methods:
22
__init__:
3-
name: __init__
43
parameters:
54
calib_batch_size:
65
annotation: int
76
default: 1
8-
name: calib_batch_size
97
calib_batches:
108
annotation: int
119
default: 512
12-
name: calib_batches
1310
calib_dataset:
1411
annotation: str
1512
default: cnn_dailymail
16-
name: calib_dataset
1713
calib_max_seq_length:
1814
annotation: int
1915
default: 512
20-
name: calib_max_seq_length
2116
device:
2217
annotation: Literal['cuda', 'cpu']
2318
default: cuda
24-
name: device
2519
random_seed:
2620
annotation: int
2721
default: 1234
28-
name: random_seed
2922
tokenizer_max_seq_length:
3023
annotation: int
3124
default: 2048
32-
name: tokenizer_max_seq_length
3325
return_annotation: None
3426
from_dict:
35-
name: from_dict
3627
parameters:
3728
config:
3829
annotation: dict
3930
default: inspect._empty
40-
name: config
4131
return_annotation: tensorrt_llm.llmapi.llm_utils.CalibConfig
4232
to_dict:
43-
name: to_dict
4433
parameters: {}
4534
return_annotation: dict
4635
properties: {}

0 commit comments

Comments
 (0)