Fix device allocation #1

Open
Kuzaiherba wants to merge 7 commits into main from fix/metrics_weights_default

Conversation

@Kuzaiherba
Owner

@Kuzaiherba Kuzaiherba commented Sep 15, 2025

Description

This PR fixes device allocation issues in PyTorch tensor creation within the metrics module. The fix ensures that newly created weight tensors are placed on the same device as the target tensors, preventing potential device mismatch errors during tensor operations.
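The pattern described above can be sketched in plain PyTorch. This is a minimal illustration under assumptions, not the code changed in this PR: `default_weights` is a hypothetical helper name, and the `"meta"` device is used only so the snippet runs on machines without a GPU or MPS backend.

```python
from typing import Optional

import torch


def default_weights(
    targets: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Return per-sample weights on the same device as `targets`.

    Hypothetical helper for illustration only; not part of chemprop.
    """
    if weights is None:
        # Without `device=`, torch.ones allocates on the default device
        # (usually the CPU), regardless of where `targets` lives.
        return torch.ones(len(targets), device=targets.device)
    return weights


# The "meta" device tracks shape and device metadata without allocating data,
# which makes it a convenient stand-in for "mps" or "cuda" here.
targets = torch.empty((10, 3), device="meta")
w = default_weights(targets)
print(w.device)
```

Passing `device=targets.device` keeps the default weights wherever the caller placed the targets, so the module does not need to know which accelerator is in use.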

Example / Current workflow

Include a sample workflow to either (a) reproduce the bug with the current codebase or (b) showcase the deficiency this PR seeks to address

import torch

from chemprop.nn import BinaryClassificationFFN
from chemprop.nn.metrics import BCELoss


task_dim = 3
predictor = BinaryClassificationFFN(
    n_tasks=task_dim,
    criterion=BCELoss(task_weights=torch.ones(task_dim)),
)

device = "mps"
predictor.to(device=device)
predictor.criterion(
    torch.ones((10, task_dim)).to(device=device),
    torch.ones((10, task_dim)).to(device=device)
)
>>> RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=BCELoss(...)` try to do `metric=BCELoss(...).to(device)` where device corresponds to the device of the input.

Bugfix / Desired workflow

Include either (a) the same workflow from above with the correct output produced via this PR or (b) some (pseudo)code containing the new workflow that this PR will (seek to) implement

import torch

from chemprop.nn import BinaryClassificationFFN
from chemprop.nn.metrics import BCELoss


task_dim = 3
predictor = BinaryClassificationFFN(
    n_tasks=task_dim,
    criterion=BCELoss(task_weights=torch.ones(task_dim)),
)

device = "mps"
predictor.to(device=device)
predictor.criterion(
    torch.ones((10, task_dim)).to(device=device),
    torch.ones((10, task_dim)).to(device=device)
)
>>> Out[1]: tensor(0.3133, device='mps:0')

Questions

If there are open questions about implementation strategy or scope of the PR, include them here

Relevant issues

If appropriate, please tag them here and include a quick summary

Checklist

  • linted with flake8?
  • (if appropriate) unit tests added?

@Kuzaiherba Kuzaiherba requested a review from Copilot September 15, 2025 21:56
Copilot AI left a comment

Pull Request Overview

This PR fixes device allocation issues in PyTorch tensor creation within the metrics module. The fix ensures that newly created weight tensors are placed on the same device as the target tensors, preventing potential device mismatch errors during tensor operations.

  • Added explicit device specification to torch.ones() calls for weight tensor creation
  • Applied the fix consistently across three different metric update methods
  • Maintained the same conditional logic while ensuring device compatibility
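As a rough illustration of the points above, the sketch below shows an `update` method that keeps its `if weights is None` branch and changes only where the default tensor is allocated. `ToyWeightedMSE` is a hypothetical stand-in written for this example; it is not one of the three chemprop metric classes touched by the PR, and its signature is an assumption.

```python
from typing import Optional

import torch


class ToyWeightedMSE:
    """Toy metric with an `update` method; hypothetical, not chemprop code."""

    def __init__(self) -> None:
        self.total = 0.0
        self.weight_sum = 0.0

    def update(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> None:
        # The conditional is unchanged; only the allocation gains `device=`.
        if weights is None:
            weights = torch.ones(targets.shape[0], device=targets.device)
        per_sample = ((preds - targets) ** 2).mean(dim=1)
        self.total += float((weights * per_sample).sum())
        self.weight_sum += float(weights.sum())

    def compute(self) -> float:
        return self.total / self.weight_sum


# Falls back to the CPU when no CUDA device is present, so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = ToyWeightedMSE()
metric.update(torch.zeros((4, 3), device=device), torch.ones((4, 3), device=device))
print(metric.compute())
```

When the default weights should match the targets in shape as well as device, `torch.ones_like(targets)` is an alternative that inherits device and dtype in one call; whether that fits depends on the weight shape each metric expects.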

@Kuzaiherba Kuzaiherba self-assigned this Sep 15, 2025
@Kuzaiherba Kuzaiherba added the bug Something isn't working label Sep 15, 2025
@Kuzaiherba Kuzaiherba marked this pull request as ready for review September 15, 2025 22:42