Skip to content

Commit bb75932

Browse files
authored
Merge pull request #1399 from pytorch/fix_missing_type
fix: fix missing float type in shape analysis
2 parents 6c56aac + 7330964 commit bb75932

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

core/partitioning/shape_analysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ void getSegmentsOutputByRunning(
102102
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
103103
} else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) {
104104
jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
105+
} else if (input->type()->isSubtypeOf(torch::jit::FloatType::get())) {
106+
jit_inputs_ivalues.push_back(ivalues_maps[input].toDouble());
105107
} else if (input->type()->kind() == torch::jit::TypeKind::ListType) {
106108
// create list
107109
jit_inputs_ivalues.push_back(ivalues_maps[input].toList());

0 commit comments

Comments
 (0)