Commit b04bef6

move DataClassJSONEncoder to a more common place for future reuse (#144)
* test model test
* yapf
* pylint
* move DataClassJSONEncoder to a common place so that we can re-use it
* yapf
1 parent: 7c90426

3 files changed (+11, -11 lines)

compiler_opt/rl/compilation_runner.py

Lines changed: 0 additions & 9 deletions
@@ -16,7 +16,6 @@
 
 import abc
 import dataclasses
-import json
 import os
 import signal
 import subprocess
@@ -50,14 +49,6 @@ class RewardStat:
   moving_average_reward: float
 
 
-class DataClassJSONEncoder(json.JSONEncoder):
-
-  def default(self, o):
-    if dataclasses.is_dataclass(o):
-      return dataclasses.asdict(o)
-    return super().default(o)
-
-
 def _overwrite_trajectory_reward(sequence_example: tf.train.SequenceExample,
                                  reward: float) -> tf.train.SequenceExample:
   """Overwrite the reward in the trace (sequence_example) with the given one.

compiler_opt/rl/constant.py

Lines changed: 10 additions & 0 deletions
@@ -14,8 +14,10 @@
 # limitations under the License.
 """Constants for policy training."""
 
+import dataclasses
 import enum
 import gin
+import json
 
 BASE_DIR = 'compiler_opt/rl'
 BASE_MODULE_DIR = 'compiler_opt.rl'
@@ -33,3 +35,11 @@ class AgentName(enum.Enum):
   BEHAVIORAL_CLONE = 0
   DQN = 1
   PPO = 2
+
+
+class DataClassJSONEncoder(json.JSONEncoder):
+
+  def default(self, o):
+    if dataclasses.is_dataclass(o):
+      return dataclasses.asdict(o)
+    return super().default(o)
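For context, this class uses the standard json.JSONEncoder extension hook: json.dump raises TypeError on dataclass instances unless default() converts them to something serializable, and here dataclasses.asdict turns them into plain dicts. A minimal sketch of the encoder in use (the ExampleStat dataclass and the values are made up for illustration; only the import path compiler_opt.rl.constant comes from this change):

    import dataclasses
    import json

    from compiler_opt.rl import constant


    @dataclasses.dataclass
    class ExampleStat:
      # Hypothetical dataclass, standing in for anything serialized this way.
      default_reward: float
      moving_average_reward: float


    # Without cls=, json.dumps would raise TypeError on the dataclass value;
    # DataClassJSONEncoder.default converts it to a dict via dataclasses.asdict.
    print(json.dumps({'module_a': ExampleStat(1.0, 0.5)},
                     cls=constant.DataClassJSONEncoder))
    # {"module_a": {"default_reward": 1.0, "moving_average_reward": 0.5}}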

compiler_opt/rl/train_locally.py

Lines changed: 1 addition & 2 deletions
@@ -152,8 +152,7 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
     logging.info('Last iteration took: %f', t2 - t1)
     t1 = t2
     with tf.io.gfile.GFile(reward_stat_map_path, 'w') as f:
-      json.dump(
-          reward_stat_map, f, cls=compilation_runner.DataClassJSONEncoder)
+      json.dump(reward_stat_map, f, cls=constant.DataClassJSONEncoder)
 
     policy_path = os.path.join(root_dir, 'policy',
                                str(llvm_trainer.global_step_numpy()))
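The loading side is not part of this diff, but it is the flip side of the same design: DataClassJSONEncoder only handles writing, and json.load returns plain dicts that callers must turn back into RewardStat themselves. A minimal sketch under stated assumptions (a flat module-name → RewardStat mapping, dict keys matching RewardStat's fields, and a hypothetical file path; none of these are confirmed by the diff):

    import json

    import tensorflow as tf

    from compiler_opt.rl import compilation_runner

    # Hypothetical path for illustration.
    reward_stat_map_path = '/tmp/reward_stat_map.json'

    with tf.io.gfile.GFile(reward_stat_map_path, 'r') as f:
      raw = json.load(f)  # JSON has no dataclass notion: values come back as dicts

    # Rebuild dataclass instances; assumes each serialized dict's keys match
    # RewardStat's fields and that the map is flat (module name -> RewardStat).
    reward_stat_map = {
        module: compilation_runner.RewardStat(**stat)
        for module, stat in raw.items()
    }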
