@@ -85,9 +85,9 @@ std::string Arg::type_name() const {
85
85
default :
86
86
return " None" ;
87
87
}
88
-
88
+
89
89
}
90
-
90
+
91
91
const torch::jit::IValue* Arg::IValue () const {
92
92
if (type_ == Type::kIValue ) {
93
93
return ptr_.ivalue ;
@@ -150,7 +150,7 @@ double Arg::unwrapToDouble(double default_val) {
150
150
151
151
double Arg::unwrapToDouble () {
152
152
return this ->unwrapTo <double >();
153
- }
153
+ }
154
154
155
155
bool Arg::unwrapToBool (bool default_val) {
156
156
return this ->unwrapTo <bool >(default_val);
@@ -194,26 +194,41 @@ c10::List<bool> Arg::unwrapToBoolList() {
194
194
195
195
template <typename T>
196
196
T Arg::unwrapTo (T default_val) {
197
- if (isIValue ()) {
198
- // TODO: implement Tag Checking
199
- return ptr_.ivalue ->to <T>();
197
+ try {
198
+ return this ->unwrapTo <T>();
199
+ } catch (trtorch::Error& e) {
200
+ LOG_DEBUG (" In arg unwrapping, returning default value provided (" << e.what () << " )" );
201
+ return default_val;
200
202
}
201
- LOG_DEBUG (" In arg unwrapping, returning default value provided" );
202
- return default_val;
203
203
}
204
204
205
-
206
205
template <typename T>
207
206
T Arg::unwrapTo () {
208
- if (isIValue ()) {
209
- // TODO: Implement Tag checking
210
- return ptr_.ivalue ->to <T>();
211
- // TODO: Exception
212
- // LOG_INTERNAL_ERROR("Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << ptr_.ivalue->type());
213
-
207
+ TRTORCH_CHECK (isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
208
+ auto ivalue = ptr_.ivalue ;
209
+ bool correct_type = false ;
210
+ if (typeid (T) == typeid (double )) {
211
+ correct_type = ivalue->isDouble ();
212
+ } else if (typeid (T) == typeid (bool )) {
213
+ correct_type = ivalue->isBool ();
214
+ } else if (typeid (T) == typeid (int64_t )) {
215
+ correct_type = ivalue->isInt ();
216
+ } else if (typeid (T) == typeid (at::Tensor)) {
217
+ correct_type = ivalue->isTensor ();
218
+ } else if (typeid (T) == typeid (c10::Scalar)) {
219
+ correct_type = ivalue->isScalar ();
220
+ } else if (typeid (T) == typeid (c10::List<int64_t >)) {
221
+ correct_type = ivalue->isIntList ();
222
+ } else if (typeid (T) == typeid (c10::List<double >)) {
223
+ correct_type = ivalue->isDoubleList ();
224
+ } else if (typeid (T) == typeid (c10::List<bool >)) {
225
+ correct_type = ivalue->isBoolList ();
226
+ } else {
227
+ TRTORCH_THROW_ERROR (" Requested unwrapping of arg to an unsupported type: " << typeid (T).name ());
214
228
}
215
- TRTORCH_THROW_ERROR (" Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
216
- return T ();
229
+
230
+ TRTORCH_CHECK (correct_type, " Requested unwrapping of arg IValue assuming it was " << typeid (T).name () << " however type is " << *(ptr_.ivalue ->type ()));
231
+ return ptr_.ivalue ->to <T>();
217
232
}
218
233
219
234
0 commit comments