15
15
from __future__ import annotations
16
16
17
17
import atexit
18
+ import base64
19
+ import json
18
20
import sys
19
21
from abc import ABC , abstractmethod
20
22
from enum import Enum
21
23
from pathlib import Path
22
- from typing import TYPE_CHECKING , ClassVar , NamedTuple
24
+ from typing import TYPE_CHECKING , Any , ClassVar , NamedTuple
23
25
24
26
from typing_extensions import Self
25
27
26
- from .envs import ENV_SOT_COLLECT_INFO
28
+ from .envs import ENV_SOT_COLLECT_INFO , ENV_SOT_SERIALIZE_INFO
27
29
from .utils import Singleton
28
30
29
31
if TYPE_CHECKING :
30
32
import types
31
33
32
34
from .exceptions import BreakGraphReasonBase
33
35
36
+ PREFIX = "<sot>"
37
+ SUFFIX = "</sot>"
38
+ ENCODING = "utf-8"
39
+
34
40
35
41
def try_import_graphviz ():
36
42
try :
@@ -102,7 +108,10 @@ def generate_report(self, info_dict: dict[str, list[InfoBase]]) -> str:
102
108
for info_class_name , info_list in info_dict .items ():
103
109
cls = info_list [0 ].__class__
104
110
report += f"{ info_class_name } ({ cls .SHORT_NAME } ):\n "
105
- report += cls .summary (info_list )
111
+ if ENV_SOT_SERIALIZE_INFO .get ():
112
+ report += cls .json_report (info_list )
113
+ else :
114
+ report += cls .summary (info_list )
106
115
report += "\n "
107
116
return report
108
117
@@ -120,6 +129,22 @@ def __init__(self): ...
120
129
@abstractmethod
121
130
def summary (cls , history : list [Self ]) -> str : ...
122
131
132
@classmethod
def serialize(cls, obj: dict[str, Any]) -> str:
    """Encode a JSON-compatible mapping as base64 text.

    Args:
        obj: Mapping to serialize. Every value must be encodable by
            ``json.dumps``.

    Returns:
        The base64 encoding of the JSON document, as ``str``.

    Raises:
        TypeError: Propagated from ``json.dumps`` when ``obj`` contains a
            value it cannot encode.
    """
    json_data = json.dumps(obj)
    b64_bytes = base64.b64encode(json_data.encode(ENCODING))

    # b64encode only emits ASCII bytes, so this decode yields plain text.
    return b64_bytes.decode(ENCODING)
139
+
140
@classmethod
def deserialize(cls, data: bytes | str) -> dict:
    """Decode base64 text produced by ``serialize`` back into an object.

    Args:
        data: Base64 payload, either as ``bytes`` or as ``str``. A ``str``
            is first encoded with ``ENCODING``.

    Returns:
        The object parsed from the decoded JSON document.
    """
    raw = data.encode(ENCODING) if isinstance(data, str) else data
    decoded_json = base64.b64decode(raw).decode(ENCODING)
    return json.loads(decoded_json)
147
+
123
148
124
149
class NewSymbolHitRateInfo (InfoBase ):
125
150
SHORT_NAME = "new_symbol_hit_rate"
@@ -154,6 +179,11 @@ def summary(cls, history: list[Self]) -> str:
154
179
summary += f"Hit rate: { hit_count / all_count :.2f} "
155
180
return summary
156
181
182
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Return the report for ``history``.

    Serialized output is not implemented for this info type yet, so the
    result of ``summary`` is returned unchanged.
    """
    # TODO: need to support serialize the output
    plain_text_report = cls.summary(history)
    return plain_text_report
186
+
157
187
158
188
class SubGraphRelationInfo (InfoBase ):
159
189
SHORT_NAME = "subgraph_relation"
@@ -241,6 +271,11 @@ def to_tensor_node_name(
241
271
dot .render (directory / filename , format = "svg" , cleanup = True )
242
272
return f"Please check { directory / filename } .svg for subgraph relation"
243
273
274
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Return the report for ``history``.

    Serialized output is not implemented for this info type yet, so the
    result of ``summary`` is returned unchanged.
    """
    # TODO: need to support serialize the output
    plain_text_report = cls.summary(history)
    return plain_text_report
278
+
244
279
245
280
class CompileCountInfo (InfoBase ):
246
281
SHORT_NAME = "compile_count"
@@ -268,6 +303,11 @@ def summary(cls, history: list[Self]) -> str:
268
303
summary = "\n " .join (summary_lines )
269
304
return summary
270
305
306
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Return the report for ``history``.

    Serialized output is not implemented for this info type yet, so the
    result of ``summary`` is returned unchanged.
    """
    # TODO: need to support serialize the output
    plain_text_report = cls.summary(history)
    return plain_text_report
310
+
271
311
272
312
class BreakGraphReasonInfo (InfoBase ):
273
313
SHORT_NAME = "breakgraph_reason"
@@ -278,17 +318,24 @@ def __init__(self, reason: BreakGraphReasonBase):
278
318
self .reason = reason
279
319
280
320
@classmethod
def classify(
    cls, history: list[Self]
) -> tuple[dict[str, list[str]], list[tuple[str, list[str]]]]:
    """Group the recorded break-graph reasons by reason class name.

    Args:
        history: Collected infos; each one exposes the original reason
            object as ``info.reason``.

    Returns:
        A pair ``(reasons_dict, sorted_reasons)``:

        * ``reasons_dict`` maps the reason's class name to the list of
          ``str(reason)`` values, in first-seen order.
        * ``sorted_reasons`` holds the same ``(name, reasons)`` items
          ordered by number of reasons, most frequent first. The sort is
          stable, so ties keep their first-seen order.
    """
    reasons_dict: dict[str, list[str]] = {}

    for info in history:
        name = info.reason.__class__.__name__
        reasons_dict.setdefault(name, []).append(str(info.reason))

    sorted_reasons = sorted(
        reasons_dict.items(), key=lambda item: len(item[1]), reverse=True
    )

    return reasons_dict, sorted_reasons
334
+
335
+ @classmethod
336
+ def summary (cls , history : list [Self ]) -> str :
337
+
338
+ reason_dict , reason_list = cls .classify (history )
292
339
293
340
return "\n " .join (
294
341
[
@@ -297,6 +344,33 @@ def summary(cls, history: list[Self]) -> str:
297
344
]
298
345
)
299
346
347
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Serialize the classified break-graph reasons into a tagged string.

    The payload maps ``SHORT_NAME`` to the per-class reason lists plus a
    ``"count"`` entry holding the number of reasons per class, ordered
    from most to least frequent.

    Returns:
        The serialized payload wrapped in ``PREFIX`` and ``SUFFIX``.
    """
    grouped, ranked = cls.classify(history)
    counts = {class_name: len(reasons) for class_name, reasons in ranked}
    payload = {**grouped, "count": counts}
    encoded = cls.serialize({cls.SHORT_NAME: payload})

    return f"{PREFIX}{encoded}{SUFFIX}"
355
+
356
@classmethod
def restore_from_string(cls, serialized: str) -> list[Self]:
    """Rebuild the history encoded by ``json_report``.

    Args:
        serialized: The base64 payload produced by ``json_report``. It may
            still be wrapped in ``PREFIX``/``SUFFIX``; the wrapper is
            stripped before decoding.

    Returns:
        One instance per recorded reason string, in payload order.

    Raises:
        ValueError: If the payload names a reason class that
            ``paddle.jit.sot.utils.exceptions`` does not define.
    """
    # Imported inside the method, as the original does
    # (presumably to avoid an import cycle -- TODO confirm).
    from paddle.jit.sot.utils import exceptions

    # json_report emits f"{PREFIX}{payload}{SUFFIX}". The markers must be
    # removed explicitly: their letters are valid base64 characters, so
    # b64decode would not skip them and the decoded bytes would be garbage.
    if serialized.startswith(PREFIX) and serialized.endswith(SUFFIX):
        serialized = serialized[len(PREFIX) : -len(SUFFIX)]

    obj = cls.deserialize(serialized)[cls.SHORT_NAME]
    # "count" is the summary entry added by json_report, not a reason class.
    obj.pop("count", None)

    history = []
    for classname, reasons in obj.items():
        ReasonClass = getattr(exceptions, classname, None)
        if ReasonClass is None:
            # Fail with a clear message instead of calling ``None``.
            raise ValueError(
                f"Unknown break graph reason class {classname!r} in serialized data"
            )
        history.extend(
            cls(ReasonClass(reason_str=reason)) for reason in reasons
        )

    return history
373
+
300
374
@staticmethod
301
375
def collect_break_graph_reason (reason : BreakGraphReasonBase ):
302
376
if not InfoCollector ().need_collect (BreakGraphReasonInfo ):
@@ -309,7 +383,8 @@ class SubGraphInfo(InfoBase):
309
383
SHORT_NAME = "subgraph_info"
310
384
TYPE = InfoType .STEP_INFO
311
385
312
- def __init__ (self , graph , op_num , sir_name ):
386
+ def __init__ (self , graph : str , op_num : int , sir_name : str ):
387
+ # NOTE: All data should be serializable
313
388
super ().__init__ ()
314
389
self .graph = graph
315
390
self .op_num = op_num
@@ -320,11 +395,12 @@ def __str__(self):
320
395
321
396
@classmethod
322
397
def summary (cls , history : list [Self ]) -> str :
323
-
324
398
num_of_subgraph = len (history )
325
399
sum_of_op_num = sum (item .op_num for item in history )
326
400
327
- need_details = "details" in ENV_SOT_COLLECT_INFO .get ()[cls .SHORT_NAME ]
401
+ need_details = "details" in ENV_SOT_COLLECT_INFO .get ().get (
402
+ cls .SHORT_NAME , []
403
+ )
328
404
329
405
details = ""
330
406
if need_details :
@@ -338,3 +414,59 @@ def summary(cls, history: list[Self]) -> str:
338
414
summary = f"[Number of subgraph]: { num_of_subgraph } [Sum of opnum]: { sum_of_op_num } "
339
415
340
416
return f"{ summary } \n { details } "
417
+
418
@classmethod
def json_report(cls, history: list[Self]) -> str:
    """Serialize the collected subgraph records into a tagged string.

    Each record becomes a dict with the keys ``"SIR_name"``, ``"OpNum"``
    and ``"Graph"``. ``"Graph"`` is the string form of the graph only when
    ``"details"`` is among the options configured for ``SHORT_NAME`` in
    ``ENV_SOT_COLLECT_INFO``; otherwise it is an empty string.

    Returns:
        The serialized payload wrapped in ``PREFIX`` and ``SUFFIX``.
    """
    enabled_options = ENV_SOT_COLLECT_INFO.get().get(cls.SHORT_NAME, [])
    include_graph = "details" in enabled_options

    entries = [
        {
            "SIR_name": item.sir_name,
            "OpNum": item.op_num,
            "Graph": str(item.graph) if include_graph else "",
        }
        for item in history
    ]

    encoded = cls.serialize({cls.SHORT_NAME: entries})

    return f"{PREFIX}{encoded}{SUFFIX}"
438
+
439
@classmethod
def restore_from_string(cls, serialized: str) -> list[Self]:
    """Rebuild the history encoded by ``json_report``.

    Args:
        serialized: The base64 payload produced by ``json_report``. It may
            still be wrapped in ``PREFIX``/``SUFFIX``; the wrapper is
            stripped before decoding.

    Returns:
        One instance per serialized entry, in payload order. ``graph`` is
        whatever ``json_report`` stored under ``"Graph"``, which is an
        empty string when details were not enabled at report time.
    """
    # json_report emits f"{PREFIX}{payload}{SUFFIX}". The markers must be
    # removed explicitly: their letters are valid base64 characters, so
    # b64decode would not skip them and the decoded bytes would be garbage.
    if serialized.startswith(PREFIX) and serialized.endswith(SUFFIX):
        serialized = serialized[len(PREFIX) : -len(SUFFIX)]

    entries = cls.deserialize(serialized)[cls.SHORT_NAME]

    # Build through ``cls`` so the result matches the ``list[Self]``
    # annotation when called on a subclass.
    return [
        cls(
            graph=entry["Graph"],
            op_num=entry["OpNum"],
            sir_name=entry["SIR_name"],
        )
        for entry in entries
    ]
457
+
458
def __eq__(self, other):
    """Compare two subgraph records field by field.

    ``op_num`` and ``sir_name`` are always compared. ``graph`` is compared
    only when ``"details"`` is among the options configured for
    ``SHORT_NAME`` in ``ENV_SOT_COLLECT_INFO``.

    Returns:
        ``NotImplemented`` when ``other`` is not a ``SubGraphInfo``, so
        Python can fall back to its default comparison instead of raising
        ``AttributeError``. Otherwise the result of the field comparison.
    """
    # NOTE(review): defining __eq__ without __hash__ makes instances
    # unhashable unless __hash__ is defined elsewhere in this class.
    if not isinstance(other, SubGraphInfo):
        return NotImplemented

    need_graph_equal = "details" in ENV_SOT_COLLECT_INFO.get().get(
        self.SHORT_NAME, []
    )

    graph_equal_or_not = (not need_graph_equal) or self.graph == other.graph

    return (
        graph_equal_or_not
        and self.op_num == other.op_num
        and self.sir_name == other.sir_name
    )
0 commit comments