|
3 | 3 | from copy import copy |
4 | 4 | import numpy as np |
5 | 5 | import io |
| 6 | +from collections import defaultdict |
6 | 7 |
|
7 | 8 | from .calibration import ( |
8 | 9 | TensorBatchDataset, |
@@ -326,10 +327,43 @@ def default_input_names(num_inputs): |
326 | 327 |
|
327 | 328 | def default_output_names(num_outputs): |
328 | 329 | return ["output_%d" % i for i in range(num_outputs)] |
329 | | - |
| 330 | + |
| 331 | + |
| 332 | +class LayerNamingNetworkWrapper(object): |
| 333 | + def __init__(self, ctx, network): |
| 334 | + self._ctx = ctx |
| 335 | + self._network = network |
| 336 | + self._layer_counts = defaultdict(lambda: 0) |
| 337 | + |
| 338 | + def _set_layer_name(self, layer): |
| 339 | + def arg_str(arg): |
| 340 | + if isinstance(arg, torch.Tensor): |
| 341 | + return "tensor(shape=%s, dtype=%s)" % (str(list(arg.shape)), str(arg.dtype)) |
| 342 | + return str(arg) |
| 343 | + |
| 344 | + self._layer_counts[layer.type.name] += 1 |
| 345 | + args = [arg_str(arg) for arg in self._ctx.method_args] |
| 346 | + kwargs = ["%s=%s" % (key, arg_str(arg)) for key, arg in self._ctx.method_kwargs.items()] |
| 347 | + layer.name = "[%s #%d] %s(%s)" % (layer.type.name, self._layer_counts[layer.type.name], |
| 348 | + self._ctx.method_str, ", ".join(args + kwargs)) |
| 349 | + |
| 350 | + def __getattr__(self, name): |
| 351 | + attr = getattr(self._network, name) |
| 352 | + if callable(attr): |
| 353 | + def wrapper(*args, **kwargs): |
| 354 | + ret = attr(*args, **kwargs) |
| 355 | + if isinstance(ret, trt.ILayer): |
| 356 | + self._set_layer_name(ret) |
| 357 | + return ret |
| 358 | + |
| 359 | + return wrapper |
| 360 | + else: |
| 361 | + return attr |
| 362 | + |
| 363 | + |
330 | 364 | class ConversionContext(object): |
331 | 365 | def __init__(self, network, converters=CONVERTERS): |
332 | | - self.network = network |
| 366 | + self.network = LayerNamingNetworkWrapper(self, network) |
333 | 367 | self.lock = False |
334 | 368 | self.method_args = None |
335 | 369 | self.method_kwargs = None |
|
0 commit comments