|
| 1 | +# trt_handler_numpy_trt10.py |
1 | 2 | # Licensed to the Apache Software Foundation (ASF) under one or more |
2 | | -# contributor license agreements. See the NOTICE file distributed with |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
3 | 4 | # this work for additional information regarding copyright ownership. |
4 | | -# The ASF licenses this file to You under the Apache License, Version 2.0 |
5 | | -# (the "License"); you may not use this file except in compliance with |
6 | | -# the License. You may obtain a copy of the License at |
7 | | -# |
8 | | -# http://www.apache.org/licenses/LICENSE-2.0 |
9 | | -# |
10 | | -# Unless required by applicable law or agreed to in writing, software |
11 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
12 | | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | -# See the License for the specific language governing permissions and |
14 | | -# limitations under the License. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0. |
15 | 6 |
|
16 | 7 | from __future__ import annotations |
17 | 8 |
|
@@ -193,7 +184,8 @@ def _resolve_output_shape(shape: Optional[Sequence[int]], |
193 | 184 | return tuple(shp) |
194 | 185 |
|
195 | 186 |
|
196 | | -def _to_contiguous_batch(x: Union[Sequence[np.ndarray], np.ndarray]) -> np.ndarray: |
| 187 | +def _to_contiguous_batch( |
| 188 | + x: Union[Sequence[np.ndarray], np.ndarray]) -> np.ndarray: |
197 | 189 | """ |
198 | 190 | Accept either an ndarray (already a batch) or a list of ndarrays (concat on axis 0). |
199 | 191 | This avoids accidental rank-5 shapes from upstream batching. |
@@ -518,7 +510,8 @@ def run_inference( |
518 | 510 | ) -> Iterable[PredictionResult]: |
519 | 511 | return self.inference_fn(batch, model, inference_args) |
520 | 512 |
|
521 | | - def get_num_bytes(self, batch: Union[Sequence[np.ndarray], np.ndarray]) -> int: |
| 513 | + def get_num_bytes( |
| 514 | + self, batch: Union[Sequence[np.ndarray], np.ndarray]) -> int: |
522 | 515 | if isinstance(batch, np.ndarray): |
523 | 516 | return int(batch.nbytes) |
524 | 517 | if isinstance(batch, (list, tuple)) and all(isinstance(a, np.ndarray) |
|
0 commit comments