Skip to content

Commit 2de4534

Browse files
committed
Remove truncated from termination criteria
1 parent 3fec945 commit 2de4534

File tree

9 files changed

+39
-15
lines changed

9 files changed

+39
-15
lines changed

rest_api/gymnasium_envs/classic_control/acrobot_env_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,12 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
147147
observation = [float(val) for val in observation]
148148

149149
step_type = TimeStepType.MID
150-
if terminated or truncated:
150+
if terminated:
151151
step_type = TimeStepType.LAST
152152

153+
if info is not None:
154+
info['truncated'] = truncated
155+
153156
step = TimeStep(observation=observation,
154157
reward=reward,
155158
step_type=step_type,

rest_api/gymnasium_envs/classic_control/cart_pole_env_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
143143
observation = [float(val) for val in observation]
144144

145145
step_type = TimeStepType.MID
146-
if terminated or truncated:
146+
if terminated:
147147
step_type = TimeStepType.LAST
148148

149+
if info is not None:
150+
info['truncated'] = truncated
151+
149152
step = TimeStep(observation=observation,
150153
reward=reward,
151154
step_type=step_type,

rest_api/gymnasium_envs/classic_control/mountain_car_env_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,12 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
137137
observation = [float(val) for val in observation]
138138

139139
step_type = TimeStepType.MID
140-
if terminated or truncated:
140+
if terminated:
141141
step_type = TimeStepType.LAST
142+
143+
if info is not None:
144+
info['truncated'] = truncated
145+
142146
step = TimeStep(observation=observation,
143147
reward=reward,
144148
step_type=step_type,

rest_api/gymnasium_envs/classic_control/pendulum_env_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,12 @@ async def step(action: float = Body(...), cidx: int = Body(...)) -> JSONResponse
152152
observation = [float(val) for val in observation]
153153

154154
step_type = TimeStepType.MID
155-
if terminated or truncated:
155+
if terminated:
156156
step_type = TimeStepType.LAST
157157

158+
if info is not None:
159+
info['truncated'] = truncated
160+
158161
step = TimeStep(observation=observation,
159162
reward=reward,
160163
step_type=step_type,

rest_api/gymnasium_envs/classic_control/v/acrobot_v_env_api.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
acrobot_v_router = APIRouter(prefix="/gymnasium/acrobot-env/v", tags=["Acrobot Vector env API"])
1616

17-
1817
NUM_COPIES = 0
1918
ENV_NAME = "Acrobot"
2019
VECTORIZATION_MODE = 'sync'
@@ -24,7 +23,6 @@
2423
0: None
2524
}
2625

27-
2826
# actions that the environment accepts
2927
ACTIONS_SPACE = {0: "apply -1 torque to the actuated joint",
3028
1: "apply 0 torque to the actuated joint",
@@ -164,7 +162,6 @@ async def reset(seed: int = Body(default=42), cidx: int = Body(...),
164162

165163
@acrobot_v_router.post("/step")
166164
async def step(action: dict[str, list[int]] = Body(title='actions'), cidx: int = Body(...)) -> JSONResponse:
167-
168165
global NUM_COPIES
169166

170167
actions = action['actions']
@@ -191,17 +188,20 @@ async def step(action: dict[str, list[int]] = Body(title='actions'), cidx: int =
191188

192189
# if we truncate or terminate
193190
# set the environment step type to finished
194-
for i, tr in enumerate(truncates):
195-
if tr:
196-
step_types[i] = TimeStepType.LAST
191+
for i, tr in enumerate(terminates):
192+
# if tr:
193+
# step_types[i] = TimeStepType.LAST
197194

198195
if terminates[i]:
199196
step_types[i] = TimeStepType.LAST
200197

198+
for i in truncates:
199+
infos[i]['truncated'] = truncates[i]
200+
201201
step = TimeStepV(observations=observations_ar,
202202
rewards=rewards,
203203
step_types=step_types,
204-
infos=[],
204+
infos=infos,
205205
discounts=[1.0] * NUM_COPIES)
206206

207207
logger.info(f'Step in environment {ENV_NAME} and index {cidx}')

rest_api/gymnasium_envs/toy_text/black_jack_env_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
133133
observation, reward, terminated, truncated, info = envs[cidx].step(action)
134134

135135
step_type = TimeStepType.MID
136-
if terminated or truncated:
136+
if terminated:
137137
step_type = TimeStepType.LAST
138138

139+
if info is not None:
140+
info['truncated'] = truncated
139141
step = TimeStep(observation=observation,
140142
reward=reward,
141143
step_type=step_type,

rest_api/gymnasium_envs/toy_text/cliffwalking_env_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,11 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
128128
observation, reward, terminated, truncated, info = envs[cidx].step(action)
129129

130130
step_type = TimeStepType.MID
131-
if terminated or truncated:
131+
if terminated:
132132
step_type = TimeStepType.LAST
133+
134+
if info is not None:
135+
info['truncated'] = truncated
133136
step = TimeStep(observation=observation,
134137
reward=reward,
135138
step_type=step_type,

rest_api/gymnasium_envs/toy_text/frozenlake_env_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,12 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
134134
observation, reward, terminated, truncated, info = envs[cidx].step(action)
135135

136136
step_type = TimeStepType.MID
137-
if terminated or truncated:
137+
if terminated:
138138
step_type = TimeStepType.LAST
139139

140+
if info is not None:
141+
info['truncated'] = truncated
142+
140143
step = TimeStep(observation=observation,
141144
reward=reward,
142145
step_type=step_type,

rest_api/gymnasium_envs/toy_text/taxi_env_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,12 @@ async def step(action: int = Body(...), cidx: int = Body(...)) -> JSONResponse:
131131
observation, reward, terminated, truncated, info = envs[cidx].step(action)
132132

133133
step_type = TimeStepType.MID
134-
if terminated or truncated:
134+
if terminated:
135135
step_type = TimeStepType.LAST
136136

137+
if info is not None:
138+
info['truncated'] = truncated
139+
137140
action_mask = info['action_mask']
138141
step = TimeStep(observation=observation,
139142
reward=reward,

0 commit comments

Comments
 (0)