Skip to content

Commit 74dd00e

Browse files
committed
fix: fix missing float type in shape analysis
Signed-off-by: Bo Wang <[email protected]>
1 parent bf2054e commit 74dd00e

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

core/partitioning/shape_analysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ void getSegmentsOutputByRunning(
102102
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
103103
} else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) {
104104
jit_inputs_ivalues.push_back(ivalues_maps[input].toBool());
105+
} else if (input->type()->isSubtypeOf(torch::jit::FloatType::get())) {
106+
jit_inputs_ivalues.push_back(ivalues_maps[input].toDouble());
105107
} else if (input->type()->kind() == torch::jit::TypeKind::ListType) {
106108
// create list
107109
jit_inputs_ivalues.push_back(ivalues_maps[input].toList());

0 commit comments

Comments
 (0)