|
27 | 27 | from compiler_opt.rl import policy_saver
|
28 | 28 |
|
29 | 29 |
|
| 30 | +# copied from the llvm regalloc generator |
| 31 | +def _gen_test_model(outdir: str): |
| 32 | + policy_decision_label = 'index_to_evict' |
| 33 | + policy_output_spec = """ |
| 34 | + [ |
| 35 | + { |
| 36 | + "logging_name": "index_to_evict", |
| 37 | + "tensor_spec": { |
| 38 | + "name": "StatefulPartitionedCall", |
| 39 | + "port": 0, |
| 40 | + "type": "int64_t", |
| 41 | + "shape": [ |
| 42 | + 1 |
| 43 | + ] |
| 44 | + } |
| 45 | + } |
| 46 | + ] |
| 47 | + """ |
| 48 | + per_register_feature_list = ['mask'] |
| 49 | + num_registers = 33 |
| 50 | + |
| 51 | + def get_input_signature(): |
| 52 | + """Returns (time_step_spec, action_spec) for LLVM register allocation.""" |
| 53 | + inputs = dict( |
| 54 | + (key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key)) |
| 55 | + for key in per_register_feature_list) |
| 56 | + return inputs |
| 57 | + |
| 58 | + module = tf.Module() |
| 59 | + # We have to set this useless variable in order for the TF C API to correctly |
| 60 | + # intake it |
| 61 | + module.var = tf.Variable(0, dtype=tf.int64) |
| 62 | + |
| 63 | + def action(*inputs): |
| 64 | + result = tf.math.argmax( |
| 65 | + tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var |
| 66 | + return {policy_decision_label: result} |
| 67 | + |
| 68 | + module.action = tf.function()(action) |
| 69 | + action = { |
| 70 | + 'action': module.action.get_concrete_function(get_input_signature()) |
| 71 | + } |
| 72 | + tf.saved_model.save(module, outdir, signatures=action) |
| 73 | + output_spec_path = os.path.join(outdir, 'output_spec.json') |
| 74 | + with tf.io.gfile.GFile(output_spec_path, 'w') as f: |
| 75 | + print(f'Writing output spec to {output_spec_path}.') |
| 76 | + f.write(policy_output_spec) |
| 77 | + |
| 78 | + |
30 | 79 | class PolicySaverTest(tf.test.TestCase):
|
31 | 80 |
|
32 | 81 | def setUp(self):
|
@@ -80,6 +129,18 @@ def test_save_policy(self):
|
80 | 129 | }
|
81 | 130 | }], json.loads(tf.io.gfile.GFile(output_signature_fn).read()))
|
82 | 131 |
|
| 132 | + def test_tflite_conversion(self): |
| 133 | + sm_dir = os.path.join(self.get_temp_dir(), 'saved_model') |
| 134 | + tflite_dir = os.path.join(self.get_temp_dir(), 'tflite_model') |
| 135 | + _gen_test_model(sm_dir) |
| 136 | + policy_saver.convert_mlgo_model(sm_dir, tflite_dir) |
| 137 | + self.assertTrue( |
| 138 | + tf.io.gfile.exists( |
| 139 | + os.path.join(tflite_dir, policy_saver.TFLITE_MODEL_NAME))) |
| 140 | + self.assertTrue( |
| 141 | + tf.io.gfile.exists( |
| 142 | + os.path.join(tflite_dir, policy_saver.OUTPUT_SIGNATURE))) |
| 143 | + |
83 | 144 |
|
84 | 145 | if __name__ == '__main__':
|
85 | 146 | tf.test.main()
|
0 commit comments