
Commit 16335cb

[hotfix] fix aten default bug (#2158)
1 parent a4b4bb0 · commit 16335cb

File tree: 10 files changed, +137 -122 lines changed


colossalai/fx/profiler/opcount.py

Lines changed: 122 additions & 116 deletions
@@ -7,6 +7,7 @@
 from typing import Any, Callable, List

 import torch
+from packaging import version

 aten = torch.ops.aten

@@ -188,131 +189,136 @@ def zero_flop_jit(*args):
     return 0


-flop_mapping = {
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    flop_mapping = {
     # gemm
-    aten.mm.default: matmul_flop_jit,
-    aten.matmul.default: matmul_flop_jit,
-    aten.addmm.default: addmm_flop_jit,
-    aten.bmm.default: bmm_flop_jit,
+        aten.mm.default: matmul_flop_jit,
+        aten.matmul.default: matmul_flop_jit,
+        aten.addmm.default: addmm_flop_jit,
+        aten.bmm.default: bmm_flop_jit,

     # convolution
-    aten.convolution.default: conv_flop_jit,
-    aten._convolution.default: conv_flop_jit,
-    aten.convolution_backward.default: conv_backward_flop_jit,
+        aten.convolution.default: conv_flop_jit,
+        aten._convolution.default: conv_flop_jit,
+        aten.convolution_backward.default: conv_backward_flop_jit,

     # normalization
-    aten.native_batch_norm.default: batchnorm_flop_jit,
-    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
-    aten.native_layer_norm.default: norm_flop_counter(2, 0),
-    aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+        aten.native_batch_norm.default: batchnorm_flop_jit,
+        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
+        aten.native_layer_norm.default: norm_flop_counter(2, 0),
+        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),

     # pooling
-    aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
-    aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
-    aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.max_pool1d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool3d.default: elementwise_flop_counter(1, 0),
-    aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
-    aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
-    aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
-    aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
-    aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
-    aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
-    aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
-    aten.embedding.default: elementwise_flop_counter(1, 0),
-}
-
-elementwise_flop_aten = [
+        aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
+        aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
+        aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.max_pool1d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool3d.default: elementwise_flop_counter(1, 0),
+        aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
+        aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
+        aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
+        aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
+        aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
+        aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
+        aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
+        aten.embedding.default: elementwise_flop_counter(1, 0),
+    }
+
+    elementwise_flop_aten = [
     # basic op
-    aten.add.Tensor,
-    aten.add_.Tensor,
-    aten.div.Tensor,
-    aten.div_.Tensor,
-    aten.div.Scalar,
-    aten.div_.Scalar,
-    aten.mul.Tensor,
-    aten.mul.Scalar,
-    aten.mul_.Tensor,
-    aten.neg.default,
-    aten.pow.Tensor_Scalar,
-    aten.rsub.Scalar,
-    aten.sum.default,
-    aten.sum.dim_IntList,
-    aten.mean.dim,
+        aten.add.Tensor,
+        aten.add_.Tensor,
+        aten.div.Tensor,
+        aten.div_.Tensor,
+        aten.div.Scalar,
+        aten.div_.Scalar,
+        aten.mul.Tensor,
+        aten.mul.Scalar,
+        aten.mul_.Tensor,
+        aten.neg.default,
+        aten.pow.Tensor_Scalar,
+        aten.rsub.Scalar,
+        aten.sum.default,
+        aten.sum.dim_IntList,
+        aten.mean.dim,

     # activation op
-    aten.hardswish.default,
-    aten.hardswish_.default,
-    aten.hardswish_backward.default,
-    aten.hardtanh.default,
-    aten.hardtanh_.default,
-    aten.hardtanh_backward.default,
-    aten.hardsigmoid_backward.default,
-    aten.hardsigmoid.default,
-    aten.gelu.default,
-    aten.gelu_backward.default,
-    aten.silu.default,
-    aten.silu_.default,
-    aten.silu_backward.default,
-    aten.sigmoid.default,
-    aten.sigmoid_backward.default,
-    aten._softmax.default,
-    aten._softmax_backward_data.default,
-    aten.relu_.default,
-    aten.relu.default,
-    aten.tanh.default,
-    aten.tanh_backward.default,
-    aten.threshold_backward.default,
+        aten.hardswish.default,
+        aten.hardswish_.default,
+        aten.hardswish_backward.default,
+        aten.hardtanh.default,
+        aten.hardtanh_.default,
+        aten.hardtanh_backward.default,
+        aten.hardsigmoid_backward.default,
+        aten.hardsigmoid.default,
+        aten.gelu.default,
+        aten.gelu_backward.default,
+        aten.silu.default,
+        aten.silu_.default,
+        aten.silu_backward.default,
+        aten.sigmoid.default,
+        aten.sigmoid_backward.default,
+        aten._softmax.default,
+        aten._softmax_backward_data.default,
+        aten.relu_.default,
+        aten.relu.default,
+        aten.tanh.default,
+        aten.tanh_backward.default,
+        aten.threshold_backward.default,

     # dropout
-    aten.native_dropout.default,
-    aten.native_dropout_backward.default,
-]
-
-for op in elementwise_flop_aten:
-    flop_mapping[op] = elementwise_flop_counter(1, 0)
-
-# TODO: this will be removed in future
-zero_flop_aten = [
-    aten.as_strided.default,
-    aten.as_strided_.default,
-    aten.bernoulli_.float,
-    aten.cat.default,
-    aten.clone.default,
-    aten.copy_.default,
-    aten.detach.default,
-    aten.expand.default,
-    aten.empty_like.default,
-    aten.new_empty.default,
-    aten.new_empty_strided.default,
-    aten.ones_like.default,
-    aten._reshape_alias.default,
-    aten.select.int,
-    aten.select_backward.default,
-    aten.squeeze.dim,
-    aten.slice.Tensor,
-    aten.slice_backward.default,
-    aten.split.Tensor,
-    aten.permute.default,
-    aten.t.default,
-    aten.transpose.int,
-    aten._to_copy.default,
-    aten.unsqueeze.default,
-    aten.unbind.int,
-    aten._unsafe_view.default,
-    aten.view.default,
-    aten.where.self,
-    aten.zero_.default,
-    aten.zeros_like.default,
-]
-
-for op in zero_flop_aten:
-    flop_mapping[op] = zero_flop_jit
+        aten.native_dropout.default,
+        aten.native_dropout_backward.default,
+    ]
+    for op in elementwise_flop_aten:
+        flop_mapping[op] = elementwise_flop_counter(1, 0)
+
+    # TODO: this will be removed in future
+    zero_flop_aten = [
+        aten.as_strided.default,
+        aten.as_strided_.default,
+        aten.bernoulli_.float,
+        aten.cat.default,
+        aten.clone.default,
+        aten.copy_.default,
+        aten.detach.default,
+        aten.expand.default,
+        aten.empty_like.default,
+        aten.new_empty.default,
+        aten.new_empty_strided.default,
+        aten.ones_like.default,
+        aten._reshape_alias.default,
+        aten.select.int,
+        aten.select_backward.default,
+        aten.squeeze.dim,
+        aten.slice.Tensor,
+        aten.slice_backward.default,
+        aten.split.Tensor,
+        aten.permute.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten._to_copy.default,
+        aten.unsqueeze.default,
+        aten.unbind.int,
+        aten._unsafe_view.default,
+        aten.view.default,
+        aten.where.self,
+        aten.zero_.default,
+        aten.zeros_like.default,
+    ]
+
+    for op in zero_flop_aten:
+        flop_mapping[op] = zero_flop_jit
+
+else:
+    flop_mapping = {}
+    elementwise_flop_aten = {}
+    zero_flop_aten = {}
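
Why the gate helps (a reading of the diff, not stated in the commit message): building flop_mapping touches dozens of torch.ops.aten.<op>.<overload> attributes at import time, and some of those overload lookups only resolve on newer torch, so the table is now constructed only when torch >= 1.12.0 and falls back to empty containers otherwise. A minimal runnable sketch of the same pattern follows; the placeholder FLOP rule and the count_flops helper are illustrative inventions, not part of the commit:

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    aten = torch.ops.aten
    # Newer torch can resolve aten overloads such as aten.mm.default,
    # so the FLOP table can be built safely at import time.
    flop_mapping = {
        aten.mm.default: lambda inputs, outputs: 0,    # placeholder rule
    }
else:
    # Older torch: fall back to an empty table so importing the profiler
    # does not crash; callers simply get no FLOP counts.
    flop_mapping = {}


def count_flops(target, inputs, outputs):
    # Hypothetical lookup helper: unmapped ops count as zero FLOPs.
    handler = flop_mapping.get(target)
    return handler(inputs, outputs) if handler is not None else 0

With the fallback in place, importing the profiler on an old torch succeeds, and unmapped ops simply report zero FLOPs.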

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py

Lines changed: 1 addition & 1 deletion
@@ -207,9 +207,9 @@ def forward(self, x1):
     assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence


+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('op', [torch.add])
 @parameterize('other_dim', [1, 2])
-@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler(op, other_dim):
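
Why this reorder likely matters: Python applies stacked decorators bottom-up, so the topmost decorator wraps everything beneath it and its check runs first at call time. Moving @run_on_environment_flag above @parameterize lets the environment check short-circuit before any parameterized expansion happens. A hedged illustration with stand-in decorators (not the actual colossalai.testing implementations):

import functools
import os


def skip_unless_flag(name):
    # Stand-in for run_on_environment_flag: run the test only if `name` is set.
    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if os.environ.get(name) != '1':
                print(f'skipping {fn.__name__}: {name} is not set')
                return None
            return fn(*args, **kwargs)

        return wrapper

    return decorator


def with_params(values):
    # Stand-in for parameterize: run the test body once per value.
    def decorator(fn):

        @functools.wraps(fn)
        def wrapper():
            for value in values:
                fn(value)

        return wrapper

    return decorator


@skip_unless_flag('AUTO_PARALLEL')    # outermost: checked once, before any expansion
@with_params([1, 2])
def test_demo(value):
    print('ran with', value)


test_demo()    # prints the skip message unless AUTO_PARALLEL=1

The same reorder is applied to test_bmm_handler and test_linear_handler below.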

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py

Lines changed: 1 addition & 1 deletion
@@ -203,8 +203,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
     assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@ def forward(self, input, other):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tensor_handler():
     model = GetItemFromTensorModel()
     tracer = ColoTracer()
@@ -96,6 +97,7 @@ def forward(self, input):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tuple_handler():
     model = GetItemFromTupleModel()
     tracer = ColoTracer()
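
The remaining test files gain this guard for the first time (plus the matching import where it was missing). A plausible shape for such a guard, sketched as a pytest.mark.skipif wrapper; the real helper lives in colossalai.testing.pytest_wrapper and its internals are an assumption here:

import os

import pytest


def run_on_environment_flag_sketch(name):
    # Hypothetical re-implementation: skip the decorated test unless the
    # named environment variable is truthy. Only the idea mirrors the
    # real colossalai.testing.pytest_wrapper helper.
    flag_set = os.environ.get(name, '').lower() in ('1', 'true', 'yes')
    return pytest.mark.skipif(not flag_set, reason=f'{name} is not set')


@run_on_environment_flag_sketch(name='AUTO_PARALLEL')
def test_example():
    assert True

Under this reading, the gated tests would run only when launched with something like AUTO_PARALLEL=1 and be skipped in a default CI pass.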

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

Lines changed: 1 addition & 1 deletion
@@ -308,8 +308,8 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(input_shape, bias=False):

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py

Lines changed: 3 additions & 3 deletions
@@ -2,15 +2,15 @@
 import torch
 import torch.nn as nn

-from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
-    NormPoolingHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ def forward(self, input, other):
         return reshape_node


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_reshape_handler():
     model = ReshapeModel()
     tracer = ColoTracer()

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class TensorConstructorModel(nn.Module):
@@ -18,6 +19,7 @@ def forward(self, x):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_where_handler():
     model = TensorConstructorModel()
     tracer = ColoTracer()

tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def forward(self, input, other):
         return relu_node


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_elementwise_handler():
     model = ReLuModel()
     tracer = ColoTracer()

tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 def _param_resharding_cost_assertion(node):
@@ -51,6 +52,7 @@ def forward(self, x):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_linear_module():
     model = LinearModel(4, 8)
     physical_mesh_id = torch.arange(0, 4)
@@ -86,6 +88,7 @@ def test_linear_module():
     _param_resharding_cost_assertion(linear_node)


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_module():
     model = ConvModel(3, 6, 2)
     physical_mesh_id = torch.arange(0, 4)
