26
26
from tf_agents .trajectories import time_step
27
27
28
28
from compiler_opt .rl import policy_saver
29
-
30
-
31
- # copied from the llvm regalloc generator
32
- def _gen_test_model (outdir : str ):
33
- policy_decision_label = 'index_to_evict'
34
- policy_output_spec = """
35
- [
36
- {
37
- "logging_name": "index_to_evict",
38
- "tensor_spec": {
39
- "name": "StatefulPartitionedCall",
40
- "port": 0,
41
- "type": "int64_t",
42
- "shape": [
43
- 1
44
- ]
45
- }
46
- }
47
- ]
48
- """
49
- per_register_feature_list = ['mask' ]
50
- num_registers = 33
51
-
52
- def get_input_signature ():
53
- """Returns (time_step_spec, action_spec) for LLVM register allocation."""
54
- inputs = dict (
55
- (key , tf .TensorSpec (dtype = tf .int64 , shape = (num_registers ), name = key ))
56
- for key in per_register_feature_list )
57
- return inputs
58
-
59
- module = tf .Module ()
60
- # We have to set this useless variable in order for the TF C API to correctly
61
- # intake it
62
- module .var = tf .Variable (0 , dtype = tf .int64 )
63
-
64
- def action (* inputs ):
65
- result = tf .math .argmax (
66
- tf .cast (inputs [0 ]['mask' ], tf .int32 ), axis = - 1 ) + module .var
67
- return {policy_decision_label : result }
68
-
69
- module .action = tf .function ()(action )
70
- action = {
71
- 'action' : module .action .get_concrete_function (get_input_signature ())
72
- }
73
- tf .saved_model .save (module , outdir , signatures = action )
74
- output_spec_path = os .path .join (outdir , 'output_spec.json' )
75
- with tf .io .gfile .GFile (output_spec_path , 'w' ) as f :
76
- print (f'Writing output spec to { output_spec_path } .' )
77
- f .write (policy_output_spec )
29
+ from compiler_opt .testing import model_test_utils
78
30
79
31
80
32
class PolicySaverTest (tf .test .TestCase ):
@@ -135,7 +87,7 @@ def test_save_policy(self):
135
87
def test_tflite_conversion (self ):
136
88
sm_dir = os .path .join (self .get_temp_dir (), 'saved_model' )
137
89
tflite_dir = os .path .join (self .get_temp_dir (), 'tflite_model' )
138
- _gen_test_model (sm_dir )
90
+ model_test_utils . gen_test_model (sm_dir )
139
91
policy_saver .convert_mlgo_model (sm_dir , tflite_dir )
140
92
self .assertTrue (
141
93
tf .io .gfile .exists (
@@ -148,7 +100,7 @@ def test_policy_serialization(self):
148
100
sm_dir = os .path .join (self .get_temp_dir (), 'model' )
149
101
orig_dir = os .path .join (self .get_temp_dir (), 'orig_model' )
150
102
dest_dir = os .path .join (self .get_temp_dir (), 'dest_model' )
151
- _gen_test_model (sm_dir )
103
+ model_test_utils . gen_test_model (sm_dir )
152
104
policy_saver .convert_mlgo_model (sm_dir , orig_dir )
153
105
154
106
serialized_policy = policy_saver .Policy .from_filesystem (orig_dir )
0 commit comments