File tree Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Original file line number Diff line number Diff line change 10
10
#include " ATen/core/jit_type.h"
11
11
12
12
#include " torch/csrc/jit/frontend/function_schema_parser.h"
13
- #include " torch/csrc/jit/ir/ir.h"
14
13
#include " torch/csrc/jit/ir/constants.h"
14
+ #include " torch/csrc/jit/ir/ir.h"
15
15
#include " torch/csrc/jit/ir/ir_views.h"
16
16
#include " torch/csrc/jit/passes/graph_fuser.h"
17
17
#include " torch/csrc/jit/passes/loop_unrolling.h"
@@ -33,9 +33,18 @@ void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph> g, std::vector<t
33
33
auto input_size = g->inputs ().size ();
34
34
auto param_it = params.rbegin ();
35
35
for (int i = input_size - 1 ; i >= 0 ; --i) {
36
- if (g->inputs ()[i]->type () != c10::TensorType::get () && g->inputs ()[i]->type ()->kind () != torch::jit::TypeKind::TupleType &&
36
+ if (g->inputs ()[i]->type () != c10::TensorType::get () &&
37
+ g->inputs ()[i]->type ()->kind () != torch::jit::TypeKind::TupleType &&
37
38
g->inputs ()[i]->type ()->kind () != torch::jit::TypeKind::ListType && param_it != params.rend ()) {
38
- auto new_constant = torch::jit::tryInsertConstant (*g, *param_it);
39
+ auto val = *param_it;
40
+ if (val.isTensor ()) {
41
+ at::Tensor val_tensor = val.toTensor ();
42
+ if (val_tensor.requires_grad ()) {
43
+ val_tensor.set_requires_grad (false );
44
+ val = val_tensor;
45
+ }
46
+ }
47
+ auto new_constant = torch::jit::tryInsertConstant (*g, val);
39
48
++param_it;
40
49
if (new_constant) {
41
50
g->inputs ()[i]->replaceAllUsesWith (*new_constant);
You can’t perform that action at this time.
0 commit comments