File tree Expand file tree Collapse file tree 3 files changed +23
-2
lines changed Expand file tree Collapse file tree 3 files changed +23
-2
lines changed Original file line number Diff line number Diff line change @@ -73,8 +73,9 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
73
73
std::cout << " ====2====" << std::endl;
74
74
std::vector<int64_t > new_shape;
75
75
if (ctx->input_is_dynamic ) {
76
- std::cout << " ====3====" << std::endl;
77
- new_shape = util::toVec (args[1 ].unwrapToIntList ().vec ());
76
+ std::cout << " ====3====: " << args[1 ].size () << std::endl;
77
+ // new_shape = util::toVec(args[1].unwrapToIntList().vec());
78
+ new_shape = util::toVec (args[1 ].unwrapToITensorList ());
78
79
std::cout << " ====4====" << std::endl;
79
80
int nbDynamicDims = 0 ;
80
81
for (size_t i = 0 ; i < new_shape.size (); i++) {
Original file line number Diff line number Diff line change @@ -146,6 +146,24 @@ bool Var::isITensor() const {
146
146
}
147
147
}
148
148
149
+ bool Var::isITensorList () const {
150
+ LOG_DEBUG (" ===== TYPE NAME: " << type_name ());
151
+ if (type_ == Type::kITensor ) {
152
+ return true ;
153
+ } else {
154
+ return false ;
155
+ }
156
+ }
157
+
158
+ bool Var::unwrapToITensorList () {
159
+ TORCHTRT_CHECK (
160
+ isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
161
+ LOG_DEBUG (" ===== TYPE NAME: " << type_name ());
162
+ auto ivalue = ptr_.ivalue ;
163
+ return false ;
164
+ // return ptr_.ivalue->to<nvinfer1::ITensor*>();
165
+ }
166
+
149
167
bool Var::isIValue () const {
150
168
if (type_ == Type::kIValue ) {
151
169
return true ;
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ class Var : torch::CustomClassHolder {
43
43
c10::Scalar unwrapToScalar ();
44
44
c10::List<int64_t > unwrapToIntList (c10::List<int64_t > default_val);
45
45
c10::List<int64_t > unwrapToIntList ();
46
+ c10::List<nvinfer1::ITensor*> unwrapToITensorList ();
46
47
c10::List<double > unwrapToDoubleList (c10::List<double > default_val);
47
48
c10::List<double > unwrapToDoubleList ();
48
49
c10::List<bool > unwrapToBoolList (c10::List<bool > default_val);
@@ -58,6 +59,7 @@ class Var : torch::CustomClassHolder {
58
59
59
60
bool isIValue () const ;
60
61
bool isITensor () const ;
62
+ bool isITensorList () const ;
61
63
bool isNone () const ;
62
64
Var::Type type () const ;
63
65
std::string type_name () const ;
You can’t perform that action at this time.
0 commit comments