Commit 3bfe071
Add device guard for xpu conv on multi device (pytorch#153345)
Add device guard for xpu conv on multi device (pytorch#153067)
# Motivation
fixes pytorch#153022
The root cause is that the XPU backend registers the convolution op using `m.impl`, which bypasses the device guard logic typically added by the code generation system. This can lead to unexpected behavior if the current device isn't explicitly set.
# Additional Context
run the following script
```python
import torch
import torchvision.models as models
torch.manual_seed(0)
model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)
device = torch.device('xpu:1') # 'xpu:0'
model = model.to(device=device, dtype=torch.float16)
data = data.to(device, dtype=torch.float16)
with torch.no_grad():
ret = model(data)
print(ret)
print("Execution finished")
```
The output is
```bash
-9.2102e-02, -7.7588e-01, -1.4111e+00, -9.2383e-01, 6.4551e-01,
-6.0730e-03, -7.8271e-01, -1.1904e+00, -4.1602e-01, 3.2715e-02,
-4.9854e-01, -6.3623e-01, -8.5107e-01, -6.8555e-01, -9.4434e-01,
-8.8672e-01, -6.7969e-01, -6.9824e-01, -2.8882e-01, 2.0312e+00]],
device='xpu:1', dtype=torch.float16)
Execution finished
```
Pull Request resolved: pytorch#153067
Approved by: https://github.com/albanD, https://github.com/EikanWang
(cherry picked from commit e06a080)
Co-authored-by: Yu, Guangye <[email protected]>1 parent fa98236 commit 3bfe071
File tree
2 files changed
+35
-0
lines changed- aten/src/ATen/native/mkldnn/xpu
- test/xpu
2 files changed
+35
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
401 | 401 | | |
402 | 402 | | |
403 | 403 | | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
404 | 409 | | |
405 | 410 | | |
406 | 411 | | |
| |||
611 | 616 | | |
612 | 617 | | |
613 | 618 | | |
| 619 | + | |
| 620 | + | |
614 | 621 | | |
615 | 622 | | |
616 | 623 | | |
| |||
675 | 682 | | |
676 | 683 | | |
677 | 684 | | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
678 | 689 | | |
679 | 690 | | |
680 | 691 | | |
| |||
701 | 712 | | |
702 | 713 | | |
703 | 714 | | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
704 | 719 | | |
705 | 720 | | |
706 | 721 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
| 3 | + | |
3 | 4 | | |
4 | 5 | | |
5 | 6 | | |
| |||
1191 | 1192 | | |
1192 | 1193 | | |
1193 | 1194 | | |
| 1195 | + | |
| 1196 | + | |
| 1197 | + | |
| 1198 | + | |
| 1199 | + | |
| 1200 | + | |
| 1201 | + | |
| 1202 | + | |
| 1203 | + | |
| 1204 | + | |
| 1205 | + | |
| 1206 | + | |
| 1207 | + | |
| 1208 | + | |
| 1209 | + | |
| 1210 | + | |
| 1211 | + | |
| 1212 | + | |
| 1213 | + | |
1194 | 1214 | | |
1195 | 1215 | | |
1196 | 1216 | | |
| |||
0 commit comments