Skip to content

Commit 04ded55

Browse files
committed
chore: play around with aten::size
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent ccad996 commit 04ded55

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

core/conversion/converters/impl/shuffle.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
7373
std::cout << "====2====" << std::endl;
7474
std::vector<int64_t> new_shape;
7575
if (ctx->input_is_dynamic) {
76-
std::cout << "====3====" << std::endl;
77-
new_shape = util::toVec(args[1].unwrapToIntList().vec());
76+
std::cout << "====3====: " << args[1].size() << std::endl;
77+
// new_shape = util::toVec(args[1].unwrapToIntList().vec());
78+
new_shape = util::toVec(args[1].unwrapToITensorList());
7879
std::cout << "====4====" << std::endl;
7980
int nbDynamicDims = 0;
8081
for (size_t i = 0; i < new_shape.size(); i++) {

core/conversion/var/Var.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,24 @@ bool Var::isITensor() const {
146146
}
147147
}
148148

149+
bool Var::isITensorList() const {
150+
LOG_DEBUG("===== TYPE NAME: " << type_name());
151+
if (type_ == Type::kITensor) {
152+
return true;
153+
} else {
154+
return false;
155+
}
156+
}
157+
158+
bool Var::unwrapToITensorList() {
159+
TORCHTRT_CHECK(
160+
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
161+
LOG_DEBUG("===== TYPE NAME: " << type_name());
162+
auto ivalue = ptr_.ivalue;
163+
return false;
164+
// return ptr_.ivalue->to<nvinfer1::ITensor*>();
165+
}
166+
149167
bool Var::isIValue() const {
150168
if (type_ == Type::kIValue) {
151169
return true;

core/conversion/var/Var.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Var : torch::CustomClassHolder {
4343
c10::Scalar unwrapToScalar();
4444
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);
4545
c10::List<int64_t> unwrapToIntList();
46+
c10::List<nvinfer1::ITensor*> unwrapToITensorList();
4647
c10::List<double> unwrapToDoubleList(c10::List<double> default_val);
4748
c10::List<double> unwrapToDoubleList();
4849
c10::List<bool> unwrapToBoolList(c10::List<bool> default_val);
@@ -58,6 +59,7 @@ class Var : torch::CustomClassHolder {
5859

5960
bool isIValue() const;
6061
bool isITensor() const;
62+
bool isITensorList() const;
6163
bool isNone() const;
6264
Var::Type type() const;
6365
std::string type_name() const;

0 commit comments

Comments
 (0)