Skip to content

Commit 6e8d7d1

Browse files
committed
added device to start_indices and end_indices
1 parent 262c494 commit 6e8d7d1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def extract_averaged_embeddings_from_packed_sequence(
4545
boundary_indices = (mask == 0).nonzero(as_tuple=False).flatten()
4646

4747
# Add start and end manually
48-
start_indices = torch.cat([torch.tensor([-1]), boundary_indices])
49-
end_indices = torch.cat([boundary_indices, torch.tensor([mask.size(0)])])
48+
start_indices = torch.cat([torch.tensor([-1], device=boundary_indices.device), boundary_indices])
49+
end_indices = torch.cat([boundary_indices, torch.tensor([mask.size(0)], device=boundary_indices.device)])
5050

5151
# Step 2: Extract embeddings between boundaries and average
5252
sample_embeddings = []

0 commit comments

Comments
 (0)