Skip to content

Commit 0f90243

Browse files
authored
[SOT] Use tuple InputSpec to avoid type check error (#69853)
1 parent f8c16f4 commit 0f90243

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def compile_graph(self, *ret_vars: VariableBase) -> CompileGraphResult:
433433
symbolic_inputs = self._find_tensor_inputs(input_names)
434434
compiled_fn = self.sir_ctx.compile_fn(
435435
statement_ir.name,
436-
[var.meta.to_input_spec() for var in symbolic_inputs],
436+
tuple(var.meta.to_input_spec() for var in symbolic_inputs),
437437
**self._kwargs,
438438
)
439439
return compiled_fn, (statement_ir, symbolic_inputs, symbolic_outputs)

python/paddle/jit/sot/symbolic/compile_cache.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def amp_cast_inputs(self, args, kwargs):
104104
def graph_size(self):
105105
if self.partial_program is None:
106106
input_spec = convert_meta_to_input_spec(
107-
[self.SIR.symbol_meta_map[symbol] for symbol in self.SIR.inputs]
107+
tuple(
108+
self.SIR.symbol_meta_map[symbol]
109+
for symbol in self.SIR.inputs
110+
)
108111
)
109112
(
110113
self.concrete_program,
@@ -181,7 +184,7 @@ def key_fn(
181184
self,
182185
context: SymbolicTraceContext,
183186
sir_name: str,
184-
input_spec: list[InputSpec],
187+
input_spec: tuple[InputSpec, ...],
185188
**kwargs,
186189
):
187190
"""
@@ -204,7 +207,7 @@ def value_fn(
204207
self,
205208
context: SymbolicTraceContext,
206209
sir_name: str,
207-
input_spec: list[InputSpec],
210+
input_spec: tuple[InputSpec, ...],
208211
**kwargs,
209212
):
210213
"""

python/paddle/jit/sot/symbolic/symbolic_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def graph_size(self):
156156

157157
return DummyFunc()
158158

159-
def compile_fn(self, sir_name: str, input_spec: list[InputSpec], **kwargs):
159+
def compile_fn(
160+
self, sir_name: str, input_spec: tuple[InputSpec, ...], **kwargs
161+
):
160162
"""
161163
start compile and return the python function, which must can be to_static without errors.
162164
"""

0 commit comments

Comments
 (0)