Skip to content

Commit db2f276

Browse files
committed
fix IW training on multiple GPUs, removed unused param
1 parent 635d55f commit db2f276

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

pufferlib/ocean/impulse_wars/scripted_agent.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ void handleWallProximity(iwEnv *e, const droneEntity *drone, const wallEntity *w
343343

344344
// charge burst until we're close enough to a death wall to burst off
345345
// of it
346-
void wallBurst(iwEnv *e, const droneEntity *drone, const float speed, const float distance, agentActions *actions) {
346+
void wallBurst(iwEnv *e, const droneEntity *drone, const float distance, agentActions *actions) {
347347
if (distance < DRONE_BURST_RADIUS_MIN) {
348348
scriptedAgentBurst(drone, actions);
349349
return;
@@ -432,7 +432,7 @@ agentActions scriptedAgentActions(iwEnv *e, droneEntity *drone) {
432432
if (entityTypeIsWall(ent->type) && ent->type == DEATH_WALL_ENTITY) {
433433
actions.brake = true;
434434
if (drone->shield == NULL) {
435-
wallBurst(e, drone, droneSpeed, b2Distance(drone->pos, ctx.point), &actions);
435+
wallBurst(e, drone, b2Distance(drone->pos, ctx.point), &actions);
436436
}
437437

438438
const b2Vec2 droneDirection = b2Normalize(drone->velocity);

pufferlib/ocean/torch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,6 @@ def __init__(
598598
num_drones: int = 2,
599599
continuous: bool = False,
600600
is_training: bool = True,
601-
device: str = "cuda",
602601
**kwargs,
603602
):
604603
super().__init__()
@@ -616,13 +615,13 @@ def __init__(
616615
+ [self.obsInfo.wallTypes + 1] * self.obsInfo.numFloatingWallObs
617616
+ [self.numDrones + 1] * self.obsInfo.numProjectileObs,
618617
)
619-
discreteOffsets = torch.tensor([0] + list(np.cumsum(self.discreteFactors)[:-1]), device=device).view(
618+
discreteOffsets = torch.tensor([0] + list(np.cumsum(self.discreteFactors)[:-1])).view(
620619
1, -1
621620
)
622621
self.register_buffer("discreteOffsets", discreteOffsets, persistent=False)
623622
self.discreteMultihotDim = self.discreteFactors.sum()
624623

625-
multihotBuffer = torch.zeros(batch_size, self.discreteMultihotDim, device=device)
624+
multihotBuffer = torch.zeros(batch_size, self.discreteMultihotDim)
626625
self.register_buffer("multihotOutput", multihotBuffer, persistent=False)
627626

628627
# most of the observation is a 2D array of bytes, but the end

0 commit comments

Comments
 (0)