diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py index 3a04b90e..229c1b67 100644 --- a/compiler_opt/rl/env.py +++ b/compiler_opt/rl/env.py @@ -179,7 +179,7 @@ def _get_step_type() -> StepType: tv_dict = {} for fv in obs.feature_values: array = fv.to_numpy() - tv_dict[fv.spec.name] = np.reshape(array, newshape=fv.spec.shape) + tv_dict[fv.spec.name] = np.reshape(array, fv.spec.shape) return TimeStep( obs=tv_dict, reward={obs.context: obs.score} if obs.score else None,