Skip to content

Commit 164a1a6

Browse files
committed
feat(trt_util): from Naren, added unpadDims tool
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent d33ec82 commit 164a1a6

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

core/util/trt_util.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,30 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
8282
return dims;
8383
}
8484

85+
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
86+
nvinfer1::Dims dims;
87+
88+
int j = 0;
89+
bool pad_dims_done = false;
90+
91+
for (int i = 0; i < d.nbDims; i++) {
92+
if (d.d[i] == 1 && !pad_dims_done) {
93+
// skip over unecessary dimension
94+
continue;
95+
} else {
96+
dims.d[j] = d.d[i];
97+
j++;
98+
99+
// keep all other dimensions (don't skip over them)
100+
pad_dims_done = true;
101+
}
102+
}
103+
104+
dims.nbDims = j;
105+
106+
return dims;
107+
}
108+
85109
std::vector<int64_t> toVec(nvinfer1::Dims d) {
86110
std::vector<int64_t> dims;
87111
for (int i = 0; i < d.nbDims; i++) {

core/util/trt_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ int64_t volume(const nvinfer1::Dims& d);
7979

8080
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
8181
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
82+
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
8283
nvinfer1::Dims toDims(c10::IntArrayRef l);
8384
nvinfer1::Dims toDims(c10::List<int64_t> l);
8485
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);

0 commit comments

Comments
 (0)