File tree Expand file tree Collapse file tree 8 files changed +16
-8
lines changed Expand file tree Collapse file tree 8 files changed +16
-8
lines changed Original file line number Diff line number Diff line change @@ -119,7 +119,8 @@ def load_balancing_loss_func(
119
119
router_per_expert_attention_mask , dim = 0
120
120
)
121
121
122
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
122
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
123
+ rank = routing_weights .shape [1 ] * int (device_index )
123
124
overall_loss = torch .sum (
124
125
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
125
126
)
Original file line number Diff line number Diff line change @@ -1647,7 +1647,8 @@ def load_balancing_loss_func(
1647
1647
router_per_expert_attention_mask , dim = 0
1648
1648
)
1649
1649
1650
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
1650
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
1651
+ rank = routing_weights .shape [1 ] * int (device_index )
1651
1652
overall_loss = torch .sum (
1652
1653
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
1653
1654
)
Original file line number Diff line number Diff line change @@ -918,7 +918,8 @@ def load_balancing_loss_func(
918
918
router_per_expert_attention_mask , dim = 0
919
919
)
920
920
921
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
921
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
922
+ rank = routing_weights .shape [1 ] * int (device_index )
922
923
overall_loss = torch .sum (
923
924
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
924
925
)
Original file line number Diff line number Diff line change @@ -148,7 +148,8 @@ def load_balancing_loss_func(
148
148
router_per_expert_attention_mask , dim = 0
149
149
)
150
150
151
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
151
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
152
+ rank = routing_weights .shape [1 ] * int (device_index )
152
153
overall_loss = torch .sum (
153
154
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
154
155
)
Original file line number Diff line number Diff line change @@ -129,7 +129,8 @@ def load_balancing_loss_func(
129
129
router_per_expert_attention_mask , dim = 0
130
130
)
131
131
132
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
132
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
133
+ rank = routing_weights .shape [1 ] * int (device_index )
133
134
overall_loss = torch .sum (
134
135
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
135
136
)
Original file line number Diff line number Diff line change @@ -118,7 +118,8 @@ def load_balancing_loss_func(
118
118
router_per_expert_attention_mask , dim = 0
119
119
)
120
120
121
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
121
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
122
+ rank = routing_weights .shape [1 ] * int (device_index )
122
123
overall_loss = torch .sum (
123
124
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
124
125
)
Original file line number Diff line number Diff line change @@ -134,7 +134,8 @@ def load_balancing_loss_func(
134
134
router_per_expert_attention_mask , dim = 0
135
135
)
136
136
137
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
137
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
138
+ rank = routing_weights .shape [1 ] * int (device_index )
138
139
overall_loss = torch .sum (
139
140
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
140
141
)
Original file line number Diff line number Diff line change @@ -137,7 +137,8 @@ def load_balancing_loss_func(
137
137
router_per_expert_attention_mask , dim = 0
138
138
)
139
139
140
- rank = routing_weights .shape [1 ] * int (routing_weights .device .index )
140
+ device_index = routing_weights .device .index if routing_weights .device .index is not None else 0
141
+ rank = routing_weights .shape [1 ] * int (device_index )
141
142
overall_loss = torch .sum (
142
143
tokens_per_expert [:, rank : rank + routing_weights .shape [1 ]] * router_prob_per_expert .unsqueeze (0 )
143
144
)
You can’t perform that action at this time.
0 commit comments