Skip to content

Commit 910b319

Browse files
Fix CI: Tests failing on CPU due to torch.device('cpu').index being None (#39933)
replace routing_weights.device.index with a fallback value of 0 when it is None (as it is for torch.device('cpu'))
1 parent 369c99d commit 910b319

File tree

8 files changed

+16
-8
lines changed

8 files changed

+16
-8
lines changed

src/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def load_balancing_loss_func(
119119
router_per_expert_attention_mask, dim=0
120120
)
121121

122-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
122+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
123+
rank = routing_weights.shape[1] * int(device_index)
123124
overall_loss = torch.sum(
124125
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
125126
)

src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1647,7 +1647,8 @@ def load_balancing_loss_func(
16471647
router_per_expert_attention_mask, dim=0
16481648
)
16491649

1650-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
1650+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
1651+
rank = routing_weights.shape[1] * int(device_index)
16511652
overall_loss = torch.sum(
16521653
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
16531654
)

src/transformers/models/granitemoeshared/modeling_granitemoeshared.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,8 @@ def load_balancing_loss_func(
918918
router_per_expert_attention_mask, dim=0
919919
)
920920

921-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
921+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
922+
rank = routing_weights.shape[1] * int(device_index)
922923
overall_loss = torch.sum(
923924
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
924925
)

src/transformers/models/jamba/modeling_jamba.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def load_balancing_loss_func(
148148
router_per_expert_attention_mask, dim=0
149149
)
150150

151-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
151+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
152+
rank = routing_weights.shape[1] * int(device_index)
152153
overall_loss = torch.sum(
153154
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
154155
)

src/transformers/models/jetmoe/modeling_jetmoe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def load_balancing_loss_func(
129129
router_per_expert_attention_mask, dim=0
130130
)
131131

132-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
132+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
133+
rank = routing_weights.shape[1] * int(device_index)
133134
overall_loss = torch.sum(
134135
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
135136
)

src/transformers/models/olmoe/modeling_olmoe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def load_balancing_loss_func(
118118
router_per_expert_attention_mask, dim=0
119119
)
120120

121-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
121+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
122+
rank = routing_weights.shape[1] * int(device_index)
122123
overall_loss = torch.sum(
123124
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
124125
)

src/transformers/models/phimoe/modeling_phimoe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def load_balancing_loss_func(
134134
router_per_expert_attention_mask, dim=0
135135
)
136136

137-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
137+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
138+
rank = routing_weights.shape[1] * int(device_index)
138139
overall_loss = torch.sum(
139140
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
140141
)

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def load_balancing_loss_func(
137137
router_per_expert_attention_mask, dim=0
138138
)
139139

140-
rank = routing_weights.shape[1] * int(routing_weights.device.index)
140+
device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
141+
rank = routing_weights.shape[1] * int(device_index)
141142
overall_loss = torch.sum(
142143
tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
143144
)

0 commit comments

Comments
 (0)