Skip to content

Commit 73bfd4c

Browse files
committed
feat(//core/conversion/converter/Arg): Add typechecking to the unwrap
functions Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 8be79e1 commit 73bfd4c

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

core/conversion/converters/Arg.cpp

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ std::string Arg::type_name() const {
8585
default:
8686
return "None";
8787
}
88-
88+
8989
}
90-
90+
9191
const torch::jit::IValue* Arg::IValue() const {
9292
if (type_ == Type::kIValue) {
9393
return ptr_.ivalue;
@@ -150,7 +150,7 @@ double Arg::unwrapToDouble(double default_val) {
150150

151151
double Arg::unwrapToDouble() {
152152
return this->unwrapTo<double>();
153-
}
153+
}
154154

155155
bool Arg::unwrapToBool(bool default_val) {
156156
return this->unwrapTo<bool>(default_val);
@@ -194,26 +194,41 @@ c10::List<bool> Arg::unwrapToBoolList() {
194194

195195
template<typename T>
196196
T Arg::unwrapTo(T default_val) {
197-
if (isIValue()) {
198-
// TODO: implement Tag Checking
199-
return ptr_.ivalue->to<T>();
197+
try {
198+
return this->unwrapTo<T>();
199+
} catch(trtorch::Error& e) {
200+
LOG_DEBUG("In arg unwrapping, returning default value provided (" << e.what() << ")");
201+
return default_val;
200202
}
201-
LOG_DEBUG("In arg unwrapping, returning default value provided");
202-
return default_val;
203203
}
204204

205-
206205
template<typename T>
207206
T Arg::unwrapTo() {
208-
if (isIValue()) {
209-
//TODO: Implement Tag checking
210-
return ptr_.ivalue->to<T>();
211-
//TODO: Exception
212-
//LOG_INTERNAL_ERROR("Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << ptr_.ivalue->type());
213-
207+
TRTORCH_CHECK(isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
208+
auto ivalue = ptr_.ivalue;
209+
bool correct_type = false;
210+
if (typeid(T) == typeid(double)) {
211+
correct_type = ivalue->isDouble();
212+
} else if (typeid(T) == typeid(bool)) {
213+
correct_type = ivalue->isBool();
214+
} else if (typeid(T) == typeid(int64_t)) {
215+
correct_type = ivalue->isInt();
216+
} else if (typeid(T) == typeid(at::Tensor)) {
217+
correct_type = ivalue->isTensor();
218+
} else if (typeid(T) == typeid(c10::Scalar)) {
219+
correct_type = ivalue->isScalar();
220+
} else if (typeid(T) == typeid(c10::List<int64_t>)) {
221+
correct_type = ivalue->isIntList();
222+
} else if (typeid(T) == typeid(c10::List<double>)) {
223+
correct_type = ivalue->isDoubleList();
224+
} else if (typeid(T) == typeid(c10::List<bool>)) {
225+
correct_type = ivalue->isBoolList();
226+
} else {
227+
TRTORCH_THROW_ERROR("Requested unwrapping of arg to an unsupported type: " << typeid(T).name());
214228
}
215-
TRTORCH_THROW_ERROR("Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
216-
return T();
229+
230+
TRTORCH_CHECK(correct_type, "Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << *(ptr_.ivalue->type()));
231+
return ptr_.ivalue->to<T>();
217232
}
218233

219234

0 commit comments

Comments
 (0)