Skip to content

Commit 27a9b82

Browse files
authored
[SOT] Add JSON dumping and Base64 encoding for log output (#71525)
1 parent 2ab13a8 commit 27a9b82

File tree

4 files changed

+268
-14
lines changed

4 files changed

+268
-14
lines changed

python/paddle/jit/sot/symbolic/compile_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def collect_subgraph_info(self, program: Program):
223223

224224
InfoCollector().attach(
225225
SubGraphInfo,
226-
program,
226+
str(program),
227227
self.graph_size(),
228228
self.SIR.name,
229229
)

python/paddle/jit/sot/utils/envs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def parse_parameterized_key(input_str: str) -> dict[str, list[str]]:
141141
"SOT_BREAK_GRAPH_ON_GET_SYMBOLIC_VALUE", False
142142
)
143143
ENV_SOT_COLLECT_INFO = PEP508LikeEnvironmentVariable("SOT_COLLECT_INFO", {})
144+
ENV_SOT_SERIALIZE_INFO = BooleanEnvironmentVariable("SOT_SERIALIZE_INFO", False)
144145
ENV_SOT_FORCE_FALLBACK_SIR_IDS = StringEnvironmentVariable(
145146
"SOT_FORCE_FALLBACK_SIR_IDS", ""
146147
)

python/paddle/jit/sot/utils/info_collector.py

Lines changed: 145 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,28 @@
1515
from __future__ import annotations
1616

1717
import atexit
18+
import base64
19+
import json
1820
import sys
1921
from abc import ABC, abstractmethod
2022
from enum import Enum
2123
from pathlib import Path
22-
from typing import TYPE_CHECKING, ClassVar, NamedTuple
24+
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
2325

2426
from typing_extensions import Self
2527

26-
from .envs import ENV_SOT_COLLECT_INFO
28+
from .envs import ENV_SOT_COLLECT_INFO, ENV_SOT_SERIALIZE_INFO
2729
from .utils import Singleton
2830

2931
if TYPE_CHECKING:
3032
import types
3133

3234
from .exceptions import BreakGraphReasonBase
3335

36+
# Markers placed around each serialized (Base64) payload in the report text.
PREFIX = "<sot>"
SUFFIX = "</sot>"
# Text encoding used when converting between str and bytes for Base64.
ENCODING = "utf-8"
39+
3440

3541
def try_import_graphviz():
3642
try:
@@ -102,7 +108,10 @@ def generate_report(self, info_dict: dict[str, list[InfoBase]]) -> str:
102108
for info_class_name, info_list in info_dict.items():
103109
cls = info_list[0].__class__
104110
report += f"{info_class_name} ({cls.SHORT_NAME}):\n"
105-
report += cls.summary(info_list)
111+
if ENV_SOT_SERIALIZE_INFO.get():
112+
report += cls.json_report(info_list)
113+
else:
114+
report += cls.summary(info_list)
106115
report += "\n"
107116
return report
108117

@@ -120,6 +129,22 @@ def __init__(self): ...
120129
@abstractmethod
121130
def summary(cls, history: list[Self]) -> str: ...
122131

132+
@classmethod
def serialize(cls, obj: dict[str, Any]) -> str:
    """Encode ``obj`` as Base64 text wrapping its JSON representation.

    Args:
        obj: A JSON-serializable mapping.

    Returns:
        The Base64 form of ``json.dumps(obj)``, decoded back to ``str`` so
        that it can be embedded in a textual report.

    Raises:
        TypeError: If ``obj`` contains a value ``json.dumps`` cannot encode.
    """
    json_data = json.dumps(obj)
    b64_bytes = base64.b64encode(json_data.encode(ENCODING))
    return b64_bytes.decode(ENCODING)
139+
140+
@classmethod
def deserialize(cls, data: bytes | str) -> dict:
    """Decode a Base64 payload back into the object it was built from.

    Args:
        data: Base64 content, either as ``bytes`` or as ``str``; a ``str``
            is first encoded with ``ENCODING``.

    Returns:
        The value obtained by JSON-parsing the Base64-decoded text.
    """
    raw = data.encode(ENCODING) if isinstance(data, str) else data
    decoded_json = base64.b64decode(raw).decode(ENCODING)
    return json.loads(decoded_json)
147+
123148

124149
class NewSymbolHitRateInfo(InfoBase):
125150
SHORT_NAME = "new_symbol_hit_rate"
@@ -154,6 +179,11 @@ def summary(cls, history: list[Self]) -> str:
154179
summary += f"Hit rate: {hit_count / all_count:.2f}"
155180
return summary
156181

182+
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Return the report text used when serialized output is requested.

    Serialization is not implemented for this info type yet, so the result
    is the plain-text ``summary`` of ``history``.
    """
    # TODO: support serializing the output
    return cls.summary(history)
186+
157187

158188
class SubGraphRelationInfo(InfoBase):
159189
SHORT_NAME = "subgraph_relation"
@@ -241,6 +271,11 @@ def to_tensor_node_name(
241271
dot.render(directory / filename, format="svg", cleanup=True)
242272
return f"Please check {directory / filename}.svg for subgraph relation"
243273

274+
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Return the report text used when serialized output is requested.

    Serialization is not implemented for this info type yet, so the result
    is the plain-text ``summary`` of ``history``.
    """
    # TODO: support serializing the output
    return cls.summary(history)
278+
244279

245280
class CompileCountInfo(InfoBase):
246281
SHORT_NAME = "compile_count"
@@ -268,6 +303,11 @@ def summary(cls, history: list[Self]) -> str:
268303
summary = "\n".join(summary_lines)
269304
return summary
270305

306+
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Return the report text used when serialized output is requested.

    Serialization is not implemented for this info type yet, so the result
    is the plain-text ``summary`` of ``history``.
    """
    # TODO: support serializing the output
    return cls.summary(history)
310+
271311

272312
class BreakGraphReasonInfo(InfoBase):
273313
SHORT_NAME = "breakgraph_reason"
@@ -278,17 +318,24 @@ def __init__(self, reason: BreakGraphReasonBase):
278318
self.reason = reason
279319

280320
@classmethod
def classify(
    cls, history: list[Self]
) -> tuple[dict[str, list[str]], list[tuple[str, list[str]]]]:
    """Group break-graph reasons by the class name of each reason.

    Args:
        history: Collected infos; each one exposes the reason object as
            ``info.reason``.

    Returns:
        A pair ``(reasons_dict, sorted_reasons)``. ``reasons_dict`` maps a
        reason class name to the ``str()`` of every reason of that class, in
        the order they were recorded. ``sorted_reasons`` holds the same
        ``(name, reasons)`` items ordered by number of reasons, most
        frequent first; ties keep their first-seen order.
    """
    reasons_dict: dict[str, list[str]] = {}

    for info in history:
        name = info.reason.__class__.__name__
        reasons_dict.setdefault(name, []).append(str(info.reason))

    # sorted() is stable, so equally frequent classes stay in insertion order.
    sorted_reasons = sorted(
        reasons_dict.items(), key=lambda item: len(item[1]), reverse=True
    )

    return reasons_dict, sorted_reasons
334+
335+
@classmethod
336+
def summary(cls, history: list[Self]) -> str:
337+
338+
reason_dict, reason_list = cls.classify(history)
292339

293340
return "\n".join(
294341
[
@@ -297,6 +344,33 @@ def summary(cls, history: list[Self]) -> str:
297344
]
298345
)
299346

347+
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Serialize the classified break-graph reasons into a tagged payload.

    The payload maps ``cls.SHORT_NAME`` to the per-class reason lists plus a
    ``"count"`` entry holding the number of reasons per class (most frequent
    first). The Base64 text is wrapped in ``PREFIX`` / ``SUFFIX``.
    """
    grouped, ranked = cls.classify(history)
    counts = {name: len(messages) for name, messages in ranked}
    payload = {cls.SHORT_NAME: {**grouped, "count": counts}}

    return PREFIX + cls.serialize(payload) + SUFFIX
355+
356+
@classmethod
def restore_from_string(cls, serialized: str) -> list[Self]:
    """Rebuild a history from the payload produced by ``json_report``.

    This is the inverse of ``json_report``, except that the restored infos
    are grouped by reason class (in payload order) rather than kept in their
    original recording order.

    Args:
        serialized: The Base64 payload with the surrounding ``PREFIX`` and
            ``SUFFIX`` markers already removed.

    Returns:
        One info per reason string found in the payload.

    Raises:
        ValueError: If the payload names something that is not a class
            defined in ``paddle.jit.sot.utils.exceptions``.
    """
    # Kept as a function-level import, as in the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from paddle.jit.sot.utils import exceptions

    obj = cls.deserialize(serialized)[cls.SHORT_NAME]
    # "count" is derived data added by json_report; tolerate its absence.
    obj.pop("count", None)

    history = []
    for classname, reasons in obj.items():
        ReasonClass = getattr(exceptions, classname, None)
        if not isinstance(ReasonClass, type):
            # Fail with a clear message instead of "'NoneType' object is
            # not callable" when the payload names an unknown class.
            raise ValueError(
                f"Unknown break graph reason class: {classname!r}"
            )
        history.extend(
            cls(ReasonClass(reason_str=reason)) for reason in reasons
        )

    return history
373+
300374
@staticmethod
301375
def collect_break_graph_reason(reason: BreakGraphReasonBase):
302376
if not InfoCollector().need_collect(BreakGraphReasonInfo):
@@ -309,7 +383,8 @@ class SubGraphInfo(InfoBase):
309383
SHORT_NAME = "subgraph_info"
310384
TYPE = InfoType.STEP_INFO
311385

312-
def __init__(self, graph, op_num, sir_name):
386+
def __init__(self, graph: str, op_num: int, sir_name: str):
387+
# NOTE: All data should be serializable
313388
super().__init__()
314389
self.graph = graph
315390
self.op_num = op_num
@@ -320,11 +395,12 @@ def __str__(self):
320395

321396
@classmethod
322397
def summary(cls, history: list[Self]) -> str:
323-
324398
num_of_subgraph = len(history)
325399
sum_of_op_num = sum(item.op_num for item in history)
326400

327-
need_details = "details" in ENV_SOT_COLLECT_INFO.get()[cls.SHORT_NAME]
401+
need_details = "details" in ENV_SOT_COLLECT_INFO.get().get(
402+
cls.SHORT_NAME, []
403+
)
328404

329405
details = ""
330406
if need_details:
@@ -338,3 +414,59 @@ def summary(cls, history: list[Self]) -> str:
338414
summary = f"[Number of subgraph]: {num_of_subgraph} [Sum of opnum]: {sum_of_op_num}"
339415

340416
return f"{summary}\n{details}"
417+
418+
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Serialize the collected subgraph records into a tagged payload.

    Each record becomes a dict with the keys ``"SIR_name"``, ``"OpNum"`` and
    ``"Graph"``. ``"Graph"`` holds the graph text only when ``"details"`` is
    requested for this info type in ``ENV_SOT_COLLECT_INFO``; otherwise it
    is an empty string. The Base64 text is wrapped in ``PREFIX`` /
    ``SUFFIX``.
    """
    need_details = "details" in ENV_SOT_COLLECT_INFO.get().get(
        cls.SHORT_NAME, []
    )

    aggregated_info_list = [
        {
            "SIR_name": record.sir_name,
            "OpNum": record.op_num,
            "Graph": str(record.graph) if need_details else "",
        }
        for record in history
    ]

    serialized = cls.serialize({cls.SHORT_NAME: aggregated_info_list})

    return f"{PREFIX}{serialized}{SUFFIX}"
438+
439+
@classmethod
def restore_from_string(cls, serialized: str) -> list[Self]:
    """Rebuild a history from the payload produced by ``json_report``.

    Args:
        serialized: The Base64 payload with the surrounding ``PREFIX`` and
            ``SUFFIX`` markers already removed.

    Returns:
        One instance per entry in the payload, in payload order. The
        ``graph`` field is an empty string for entries that were serialized
        without details.
    """
    entries = cls.deserialize(serialized)[cls.SHORT_NAME]

    # Build through ``cls`` so the result matches the declared
    # ``list[Self]`` when called on a subclass.
    return [
        cls(
            graph=entry["Graph"],
            op_num=entry["OpNum"],
            sir_name=entry["SIR_name"],
        )
        for entry in entries
    ]
457+
458+
def __eq__(self, other):
    """Compare two subgraph records field by field.

    ``op_num`` and ``sir_name`` are always compared. ``graph`` is compared
    only when ``"details"`` is requested for this info type in
    ``ENV_SOT_COLLECT_INFO``, because ``json_report`` drops the graph text
    otherwise.

    Returns:
        ``NotImplemented`` when ``other`` is not a ``SubGraphInfo`` so that
        Python can fall back to its default comparison; otherwise a bool.
    """
    # NOTE: defining __eq__ without __hash__ leaves instances unhashable.
    if not isinstance(other, SubGraphInfo):
        return NotImplemented

    need_graph_equal = "details" in ENV_SOT_COLLECT_INFO.get().get(
        self.SHORT_NAME, []
    )

    graph_equal_or_not = True
    if need_graph_equal:
        graph_equal_or_not = self.graph == other.graph

    return (
        graph_equal_or_not
        and self.op_num == other.op_num
        and self.sir_name == other.sir_name
    )

test/sot/test_info_collect.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import random
17+
import string
18+
import unittest
19+
20+
from test_case_base import TestCaseBase
21+
22+
from paddle.jit.sot.utils.exceptions import (
23+
BuiltinFunctionBreak,
24+
DataDependencyControlFlowBreak,
25+
DataDependencyDynamicShapeBreak,
26+
DataDependencyOperationBreak,
27+
DygraphInconsistentWithStaticBreak,
28+
FallbackInlineCallBreak,
29+
InferMetaBreak,
30+
InlineCallBreak,
31+
OtherInlineCallBreak,
32+
PsdbBreakReason,
33+
SideEffectBreak,
34+
UnsupportedIteratorBreak,
35+
UnsupportedPaddleAPIBreak,
36+
)
37+
from paddle.jit.sot.utils.info_collector import (
38+
BreakGraphReasonInfo,
39+
InfoBase,
40+
SubGraphInfo,
41+
)
42+
43+
def generate_random_string(N):
    """Return ``N`` random characters drawn from ``A-Z`` and ``0-9``."""
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
46+
47+
48+
class TestSerialize(TestCaseBase):
    """Round-trip check for ``InfoBase.serialize`` / ``InfoBase.deserialize``."""

    def test_case(self):
        original = {'a': 'b', 'c': 'd'}

        restored = InfoBase.deserialize(InfoBase.serialize(original))

        self.assertEqual(original, restored)
59+
60+
61+
class TestBreakGraphReasonInfo(TestCaseBase):
    """Check that per-class reason counts survive a serialize/restore cycle."""

    def test_case(self):
        reason_classes = [
            FallbackInlineCallBreak,
            DataDependencyControlFlowBreak,
            DataDependencyDynamicShapeBreak,
            DataDependencyOperationBreak,
            UnsupportedPaddleAPIBreak,
            BuiltinFunctionBreak,
            SideEffectBreak,
            UnsupportedIteratorBreak,
            InlineCallBreak,
            OtherInlineCallBreak,
            DygraphInconsistentWithStaticBreak,
            PsdbBreakReason,
            InferMetaBreak,
        ]

        # One info per reason class, each carrying a short random message.
        history = []
        for reason_cls in reason_classes:
            message = generate_random_string(random.randint(1, 5))
            history.append(BreakGraphReasonInfo(reason_cls(reason_str=message)))

        report = BreakGraphReasonInfo.json_report(history)
        # Strip `<sot>` and `</sot>` by slicing: `removeprefix` and
        # `removesuffix` are only available from python3.9.
        payload = report[5:-6]
        restored = BreakGraphReasonInfo.restore_from_string(payload)

        def count_by_class(infos):
            grouped, _ = BreakGraphReasonInfo.classify(infos)
            return {name: len(items) for name, items in grouped.items()}

        self.assertEqual(count_by_class(history), count_by_class(restored))
100+
101+
102+
class TestSubGraphInfo(TestCaseBase):
    """Check that subgraph records survive a serialize/restore cycle."""

    def test_case(self):
        # Build ten independent records. Multiplying a one-element list
        # would repeat a single instance (and a single random draw).
        history = [
            SubGraphInfo(
                generate_random_string(random.randint(1, 5)),
                random.randint(0, 20),
                generate_random_string(random.randint(1, 5)),
            )
            for _ in range(10)
        ]

        serialized = SubGraphInfo.json_report(history)
        deserialized = SubGraphInfo.restore_from_string(
            serialized[5:-6]  # remove `<sot>` and `</sot>`
        )  # `removeprefix` & `removesuffix` are only available from python3.9

        self.assertEqual(history, deserialized)
118+
119+
120+
if __name__ == "__main__":
121+
unittest.main()

0 commit comments

Comments
 (0)