Skip to content

Commit 28a0353

Browse files
Refactor gen_test_model to separate module
This patch factors the _gen_test_model function in policy_saver_test into a separate module in a new testing subdirectory to facilitate reuse across tests. This is intended to be used in the regalloc_trace_worker test which has to create a test model to test that everything works with TFLite. Reviewers: mtrofin Reviewed By: mtrofin Pull Request: #414
1 parent f179ce1 commit 28a0353

File tree

3 files changed

+85
-51
lines changed

3 files changed

+85
-51
lines changed

compiler_opt/rl/policy_saver_test.py

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,55 +26,7 @@
2626
from tf_agents.trajectories import time_step
2727

2828
from compiler_opt.rl import policy_saver
29-
30-
31-
# copied from the llvm regalloc generator
32-
def _gen_test_model(outdir: str):
33-
policy_decision_label = 'index_to_evict'
34-
policy_output_spec = """
35-
[
36-
{
37-
"logging_name": "index_to_evict",
38-
"tensor_spec": {
39-
"name": "StatefulPartitionedCall",
40-
"port": 0,
41-
"type": "int64_t",
42-
"shape": [
43-
1
44-
]
45-
}
46-
}
47-
]
48-
"""
49-
per_register_feature_list = ['mask']
50-
num_registers = 33
51-
52-
def get_input_signature():
53-
"""Returns (time_step_spec, action_spec) for LLVM register allocation."""
54-
inputs = dict(
55-
(key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key))
56-
for key in per_register_feature_list)
57-
return inputs
58-
59-
module = tf.Module()
60-
# We have to set this useless variable in order for the TF C API to correctly
61-
# intake it
62-
module.var = tf.Variable(0, dtype=tf.int64)
63-
64-
def action(*inputs):
65-
result = tf.math.argmax(
66-
tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
67-
return {policy_decision_label: result}
68-
69-
module.action = tf.function()(action)
70-
action = {
71-
'action': module.action.get_concrete_function(get_input_signature())
72-
}
73-
tf.saved_model.save(module, outdir, signatures=action)
74-
output_spec_path = os.path.join(outdir, 'output_spec.json')
75-
with tf.io.gfile.GFile(output_spec_path, 'w') as f:
76-
print(f'Writing output spec to {output_spec_path}.')
77-
f.write(policy_output_spec)
29+
from compiler_opt.testing import model_test_utils
7830

7931

8032
class PolicySaverTest(tf.test.TestCase):
@@ -135,7 +87,7 @@ def test_save_policy(self):
13587
def test_tflite_conversion(self):
13688
sm_dir = os.path.join(self.get_temp_dir(), 'saved_model')
13789
tflite_dir = os.path.join(self.get_temp_dir(), 'tflite_model')
138-
_gen_test_model(sm_dir)
90+
model_test_utils.gen_test_model(sm_dir)
13991
policy_saver.convert_mlgo_model(sm_dir, tflite_dir)
14092
self.assertTrue(
14193
tf.io.gfile.exists(
@@ -148,7 +100,7 @@ def test_policy_serialization(self):
148100
sm_dir = os.path.join(self.get_temp_dir(), 'model')
149101
orig_dir = os.path.join(self.get_temp_dir(), 'orig_model')
150102
dest_dir = os.path.join(self.get_temp_dir(), 'dest_model')
151-
_gen_test_model(sm_dir)
103+
model_test_utils.gen_test_model(sm_dir)
152104
policy_saver.convert_mlgo_model(sm_dir, orig_dir)
153105

154106
serialized_policy = policy_saver.Policy.from_filesystem(orig_dir)

compiler_opt/testing/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Utilities for running tests that involve tensorflow model.s"""
16+
17+
import os
18+
19+
import tensorflow as tf
20+
21+
22+
# copied from the llvm regalloc generator
23+
def gen_test_model(outdir: str):
24+
policy_decision_label = 'index_to_evict'
25+
policy_output_spec = """
26+
[
27+
{
28+
"logging_name": "index_to_evict",
29+
"tensor_spec": {
30+
"name": "StatefulPartitionedCall",
31+
"port": 0,
32+
"type": "int64_t",
33+
"shape": [
34+
1
35+
]
36+
}
37+
}
38+
]
39+
"""
40+
per_register_feature_list = ['mask']
41+
num_registers = 33
42+
43+
def get_input_signature():
44+
"""Returns (time_step_spec, action_spec) for LLVM register allocation."""
45+
inputs = dict(
46+
(key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key))
47+
for key in per_register_feature_list)
48+
return inputs
49+
50+
module = tf.Module()
51+
# We have to set this useless variable in order for the TF C API to correctly
52+
# intake it
53+
module.var = tf.Variable(0, dtype=tf.int64)
54+
55+
def action(*inputs):
56+
result = tf.math.argmax(
57+
tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
58+
return {policy_decision_label: result}
59+
60+
module.action = tf.function()(action)
61+
action = {
62+
'action': module.action.get_concrete_function(get_input_signature())
63+
}
64+
tf.saved_model.save(module, outdir, signatures=action)
65+
output_spec_path = os.path.join(outdir, 'output_spec.json')
66+
with tf.io.gfile.GFile(output_spec_path, 'w') as f:
67+
print(f'Writing output spec to {output_spec_path}.')
68+
f.write(policy_output_spec)

0 commit comments

Comments
 (0)