Skip to content

Commit ab977f5

Browse files
committed
fix: deal with edge cases when introduced value is Tensor with gradient
Signed-off-by: Bo Wang <[email protected]>
1 parent 05bf80c commit ab977f5

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

core/compiler.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
#include "ATen/core/jit_type.h"
1111

1212
#include "torch/csrc/jit/frontend/function_schema_parser.h"
13-
#include "torch/csrc/jit/ir/ir.h"
1413
#include "torch/csrc/jit/ir/constants.h"
14+
#include "torch/csrc/jit/ir/ir.h"
1515
#include "torch/csrc/jit/ir/ir_views.h"
1616
#include "torch/csrc/jit/passes/graph_fuser.h"
1717
#include "torch/csrc/jit/passes/loop_unrolling.h"
@@ -33,9 +33,18 @@ void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph> g, std::vector<t
3333
auto input_size = g->inputs().size();
3434
auto param_it = params.rbegin();
3535
for (int i = input_size - 1; i >= 0; --i) {
36-
if (g->inputs()[i]->type() != c10::TensorType::get() && g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
36+
if (g->inputs()[i]->type() != c10::TensorType::get() &&
37+
g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
3738
g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) {
38-
auto new_constant = torch::jit::tryInsertConstant(*g, *param_it);
39+
auto val = *param_it;
40+
if (val.isTensor()) {
41+
at::Tensor val_tensor = val.toTensor();
42+
if (val_tensor.requires_grad()) {
43+
val_tensor.set_requires_grad(false);
44+
val = val_tensor;
45+
}
46+
}
47+
auto new_constant = torch::jit::tryInsertConstant(*g, val);
3948
++param_it;
4049
if (new_constant) {
4150
g->inputs()[i]->replaceAllUsesWith(*new_constant);

0 commit comments

Comments
 (0)