File tree Expand file tree Collapse file tree 3 files changed +11
-11
lines changed Expand file tree Collapse file tree 3 files changed +11
-11
lines changed Original file line number Diff line number Diff line change 16
16
17
17
import abc
18
18
import dataclasses
19
- import json
20
19
import os
21
20
import signal
22
21
import subprocess
@@ -50,14 +49,6 @@ class RewardStat:
50
49
moving_average_reward : float
51
50
52
51
53
- class DataClassJSONEncoder (json .JSONEncoder ):
54
-
55
- def default (self , o ):
56
- if dataclasses .is_dataclass (o ):
57
- return dataclasses .asdict (o )
58
- return super ().default (o )
59
-
60
-
61
52
def _overwrite_trajectory_reward (sequence_example : tf .train .SequenceExample ,
62
53
reward : float ) -> tf .train .SequenceExample :
63
54
"""Overwrite the reward in the trace (sequence_example) with the given one.
Original file line number Diff line number Diff line change 14
14
# limitations under the License.
15
15
"""Constants for policy training."""
16
16
17
+ import dataclasses
17
18
import enum
18
19
import gin
20
+ import json
19
21
20
22
BASE_DIR = 'compiler_opt/rl'
21
23
BASE_MODULE_DIR = 'compiler_opt.rl'
@@ -33,3 +35,11 @@ class AgentName(enum.Enum):
33
35
BEHAVIORAL_CLONE = 0
34
36
DQN = 1
35
37
PPO = 2
38
+
39
+
40
+ class DataClassJSONEncoder (json .JSONEncoder ):
41
+
42
+ def default (self , o ):
43
+ if dataclasses .is_dataclass (o ):
44
+ return dataclasses .asdict (o )
45
+ return super ().default (o )
Original file line number Diff line number Diff line change @@ -152,8 +152,7 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
152
152
logging .info ('Last iteration took: %f' , t2 - t1 )
153
153
t1 = t2
154
154
with tf .io .gfile .GFile (reward_stat_map_path , 'w' ) as f :
155
- json .dump (
156
- reward_stat_map , f , cls = compilation_runner .DataClassJSONEncoder )
155
+ json .dump (reward_stat_map , f , cls = constant .DataClassJSONEncoder )
157
156
158
157
policy_path = os .path .join (root_dir , 'policy' ,
159
158
str (llvm_trainer .global_step_numpy ()))
You can’t perform that action at this time.
0 commit comments