Skip to content

Commit a525b3e

Browse files
committed
Handle 3.8->3.9 unittest differences (#38)
1 parent 87104f3 commit a525b3e

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

compiler_opt/rl/data_collector_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Tests for data_collector."""
1616

1717
# pylint: disable=protected-access
18-
18+
import sys
1919
from unittest import mock
2020

2121
from absl.testing import absltest
@@ -28,7 +28,12 @@ class DataCollectorTest(absltest.TestCase):
2828
def test_build_distribution_monitor(self):
2929
data = [3, 2, 1]
3030
monitor_dict = data_collector.build_distribution_monitor(data)
31-
self.assertEqual(monitor_dict, monitor_dict | {'mean': 2, 'p_0.1': 1})
31+
reference_dict = {'mean': 2, 'p_0.1': 1}
32+
# Issue #38
33+
if sys.version_info.minor >= 9:
34+
self.assertEqual(monitor_dict, monitor_dict | reference_dict)
35+
else:
36+
self.assertEqual(monitor_dict, {**monitor_dict, **reference_dict})
3237

3338
@mock.patch('time.time')
3439
def test_early_exit(self, mock_time):

compiler_opt/rl/data_reader.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,13 @@ def create_sequence_example_dataset_fn(
165165
def _sequence_example_dataset_fn(sequence_examples):
166166
# Data collector returns empty strings for corner cases, filter them out
167167
# here.
168-
dataset = (tf.data.Dataset.from_tensor_slices(sequence_examples)
169-
.filter(lambda string: tf.strings.length(string) > 0)
170-
.map(parser_fn)
171-
.filter(lambda traj: tf.size(traj.reward) > 2)
172-
.unbatch()
173-
.batch(train_sequence_length, drop_remainder=True)
174-
.cache()
175-
.shuffle(trajectory_shuffle_buffer_size)
176-
.batch(batch_size, drop_remainder=True)
177-
)
168+
dataset = (
169+
tf.data.Dataset.from_tensor_slices(sequence_examples).filter(
170+
lambda string: tf.strings.length(string) > 0).map(parser_fn).filter(
171+
lambda traj: tf.size(traj.reward) > 2).unbatch().batch(
172+
train_sequence_length, drop_remainder=True).cache().shuffle(
173+
trajectory_shuffle_buffer_size).batch(
174+
batch_size, drop_remainder=True))
178175
return dataset
179176

180177
return _sequence_example_dataset_fn

compiler_opt/rl/local_data_collector_test.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import multiprocessing as mp
2020
import string
2121
import subprocess
22+
import sys
2223
from unittest import mock
2324

2425
import tensorflow as tf
@@ -147,7 +148,15 @@ def _test_iterator_fn(data_list):
147148
'total_trajectory_length': 18,
148149
}
149150
}
150-
self.assertEqual(monitor_dict, monitor_dict | expected_monitor_dict_subset)
151+
# Issue #38
152+
if sys.version_info.minor >= 9:
153+
self.assertEqual(monitor_dict,
154+
monitor_dict | expected_monitor_dict_subset)
155+
else:
156+
self.assertEqual(monitor_dict, {
157+
**monitor_dict,
158+
**expected_monitor_dict_subset
159+
})
151160

152161
data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
153162
data = list(data_iterator)
@@ -158,7 +167,15 @@ def _test_iterator_fn(data_list):
158167
'total_trajectory_length': 18,
159168
}
160169
}
161-
self.assertEqual(monitor_dict, monitor_dict | expected_monitor_dict_subset)
170+
# Issue #38
171+
if sys.version_info.minor >= 9:
172+
self.assertEqual(monitor_dict,
173+
monitor_dict | expected_monitor_dict_subset)
174+
else:
175+
self.assertEqual(monitor_dict, {
176+
**monitor_dict,
177+
**expected_monitor_dict_subset
178+
})
162179

163180
collector.close_pool()
164181

0 commit comments

Comments
 (0)