Skip to content

Commit 7928bb7

Browse files
authored
[SOT][CINN] Add flag to disable SOT fallback when compile time exceed compile time limit (PaddlePaddle#76386)
1 parent d1b3ba4 commit 7928bb7

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

python/paddle/jit/sot/opcode_translator/executor/executor_cache.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ...psdb import NO_FALLBACK_CODES
2626
from ...utils import (
2727
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
28+
ENV_SOT_ENABLE_COMPILE_TIME_LIMIT,
2829
ENV_SOT_ENABLE_GUARD_TREE,
2930
ENV_SOT_ENABLE_STRICT_GUARD_CHECK,
3031
ENV_SOT_UNSAFE_CACHE_FASTPATH,
@@ -183,6 +184,7 @@ def lookup(
183184
enable_strict_guard = ENV_SOT_ENABLE_STRICT_GUARD_CHECK.get()
184185
enable_guard_tree = ENV_SOT_ENABLE_GUARD_TREE.get()
185186
enable_unsafe_cache_fastpath = ENV_SOT_UNSAFE_CACHE_FASTPATH.get()
187+
enable_compile_time_limit = ENV_SOT_ENABLE_COMPILE_TIME_LIMIT.get()
186188

187189
if enable_unsafe_cache_fastpath and (
188190
self.is_fastpath_threshold_reached(code)
@@ -205,13 +207,19 @@ def lookup(
205207
return guarded_fns[cache_index][0]
206208
else:
207209
log(2, "[Cache] all guards missed (guard tree mode)\n")
208-
if compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE:
210+
if (
211+
enable_compile_time_limit
212+
and compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE
213+
):
209214
log(
210215
2,
211216
"[Cache] Exceed max compile time per code, skip it\n",
212217
)
213218
return CustomCode(None, False)
214-
if compile_time_total >= self.MAX_COMPILE_TIME_TOTAL:
219+
if (
220+
enable_compile_time_limit
221+
and compile_time_total >= self.MAX_COMPILE_TIME_TOTAL
222+
):
215223
log_once(
216224
f"[SOT] Current total compile time is {compile_time_total}, exceed max compile time total {self.MAX_COMPILE_TIME_TOTAL}, fallback new function to dygraph"
217225
)
@@ -307,10 +315,16 @@ def lookup(
307315
)
308316

309317
log(2, "[Cache] all guards missed\n")
310-
if compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE:
318+
if (
319+
enable_compile_time_limit
320+
and compile_time_for_code >= self.MAX_COMPILE_TIME_PER_CODE
321+
):
311322
log(2, "[Cache] Exceed max compile time per code, skip it\n")
312323
return CustomCode(None, False)
313-
if compile_time_total >= self.MAX_COMPILE_TIME_TOTAL:
324+
if (
325+
enable_compile_time_limit
326+
and compile_time_total >= self.MAX_COMPILE_TIME_TOTAL
327+
):
314328
log_once(
315329
f"[SOT] Current compile time total is {compile_time_total}, exceed max compile time total {self.MAX_COMPILE_TIME_TOTAL}, fallback new function to dygraph"
316330
)

python/paddle/jit/sot/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ENV_SOT_CE_DEBUG_MODE,
2020
ENV_SOT_COLLECT_INFO,
2121
ENV_SOT_ENABLE_0_SIZE_FALLBACK,
22+
ENV_SOT_ENABLE_COMPILE_TIME_LIMIT,
2223
ENV_SOT_ENABLE_FASTER_GUARD,
2324
ENV_SOT_ENABLE_GUARD_TREE,
2425
ENV_SOT_ENABLE_STRICT_GUARD_CHECK,

python/paddle/jit/sot/utils/envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def parse_parameterized_key(input_str: str) -> dict[str, list[str]]:
157157
ENV_SOT_SPECIALIZED_DIM_NUMBERS = StringEnvironmentVariable(
158158
"SOT_SPECIALIZED_DIM_NUMBERS", "0"
159159
)
160+
ENV_SOT_ENABLE_COMPILE_TIME_LIMIT = BooleanEnvironmentVariable(
161+
"SOT_ENABLE_COMPILE_TIME_LIMIT", True
162+
)
160163

161164

162165
def update_ce_flags():

0 commit comments

Comments
 (0)