diff --git a/spacy_curated_transformers/models/architectures.py b/spacy_curated_transformers/models/architectures.py index 51c2619..5306d6b 100644 --- a/spacy_curated_transformers/models/architectures.py +++ b/spacy_curated_transformers/models/architectures.py @@ -734,7 +734,11 @@ def _convert_inputs( span = X[i] span_len = span.shape[0] Xt[i, :span_len] = span - Xt = xp2torch(Xt) + if ops.device_type == 'gpu': + device = torch.device(f"cuda:{ops.device_id}") + Xt = xp2torch(Xt, device=device) + else: + Xt = xp2torch(Xt) def convert_from_torch_backward(d_inputs: Any): # No gradients for the inputs.