diff --git a/convbench/conv_utils.py b/convbench/conv_utils.py index 4f48d89..dbf4833 100644 --- a/convbench/conv_utils.py +++ b/convbench/conv_utils.py @@ -13,6 +13,8 @@ CONV_Q = r"""%c0_i32 = arith.constant 0 : i32 %11 = linalg.conv_2d_{CONV_TYPE}_q {{dilations = dense<1> : vector<2xi64>, strides = dense<{STRIDE}> : vector<2xi64>}} ins(%arg0, %arg1, %c0_i32, %c0_i32 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>, i32, i32) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>""" +CONV_3D = r"""%11 = linalg.conv_3d_{CONV_TYPE} {dilations = dense<1> : tensor<3xi64>, strides = dense<{STRIDE}> : tensor<3xi64>} ins (%arg0, %arg1: tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>>) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>""" + TEST = r"""util.func public @{FUNC_NAME}({FUNC_ARGS}) -> tensor<{OUT_TYPE}> {{{CONSTANT_INPUTS} %cst = arith.constant {ZERO} : {OUT_ELEM_TYPE} %9 = tensor.empty() : tensor<{OUT_TYPE}> @@ -33,30 +35,36 @@ class ConvConfig: Q: int F: int S: int + is_grouped_conv: bool + G: int # group count + is_3D_conv: bool + D: int # input depth + R: int # filter depth + S_D: int # stride along depth OP: str input_dtype: str output_dtype: str def get_name(self) -> str: - return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S) + return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S) + "_groupcount" + str(self.G) def get_img_shape(self) -> str: + in_h = self.H * self.S + self.P - 1 + in_w = self.W * self.S + self.Q - 1 if "nhwc" in self.OP: - in_h = self.H * self.S + self.P - 1 - in_w = self.W * self.S + self.Q - 1 return str(self.N) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(self.C) + "x" + self.input_dtype if "nchw" in self.OP: - in_h = self.H * self.S + self.P - 1 - in_w = self.W * self.S + self.Q - 1 return str(self.N) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype - + if "ngchw" in operation: + return str(self.N) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype def get_kernel_shape(self) -> str: if "nhwc" in self.OP: return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype if "nchw" in self.OP: return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype - + if "ngchw" in operation: + return str(self.F) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype def get_byte_count(self) -> int: dtype_bits_map = { @@ -73,15 +81,16 @@ def get_byte_count(self) -> int: in_h = self.H * self.S + self.P - 1 in_w = self.W * self.S + self.Q - 1 input_channels = self.C + group_count = self.G output_channels = self.F output_width = self.W output_height = self.H k_width = self.Q k_height = self.P byte_count = ( - (batch * input_channels * in_w * in_h * bytes_per_input) - + (batch * output_channels * output_width * output_height * bytes_per_output) - + (k_width * k_height * input_channels * output_channels * bytes_per_input) + (batch * group_count * input_channels * in_w * in_h * bytes_per_input) + + (batch * group_count * output_channels * output_width * output_height * bytes_per_output) + + (group_count * k_width * k_height * input_channels * output_channels * bytes_per_input) ) return byte_count @@ -90,6 +99,7 @@ def get_flops(self) -> int: in_h = self.H * self.S + self.P - 1 in_w = self.W * self.S + self.Q - 1 input_channels = self.C + group_count = self.G output_channels = self.F output_width = self.W output_height = self.H @@ -97,7 +107,7 @@ def get_flops(self) -> int: k_height = self.P operation_per_pixel = k_width * k_height * input_channels * 2 output_pixels_per_batch = output_width * output_height * output_channels - flops = operation_per_pixel * output_pixels_per_batch * batch + flops = operation_per_pixel * output_pixels_per_batch * group_count * batch return flops def generate_mlir(config: ConvConfig): @@ -109,11 +119,16 @@ def generate_mlir(config: ConvConfig): q = config.Q f = config.F stride = config.S + g = config.G + d = config.D + r = config.R + s_d = config.S_D operation = config.OP dtypes = f"{config.input_dtype}x{config.input_dtype}x{config.output_dtype}" elem_types = dtypes.split("x") in_h = str(int(h) * int(stride) + int(p) - 1) in_w = str(int(w) * int(stride) + int(q) - 1) + in_d = str(int(d) * int(s_d) + int(r) - 1) if "nhwc" in operation: conv_type = "nhwc_hwcf" lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0]) @@ -124,6 +139,17 @@ def generate_mlir(config: ConvConfig): lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0]) rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1]) out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2]) + if "ngchw" in operation: + conv_type = "ngchw_fgchw" + lhs = str(n) + "x" + str(g) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0]) + rhs = str(f) + "x" + str(g) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1]) + out = str(n) + "x" + str(g) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2]) + if "ncdhw" in operation: + conv_type = "ncdhw_fcdhw" + lhs = str(n) + "x" + str(c) + "x" + str(in_d) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0]) + rhs = str(f) + "x" + str(c) + "x" + str(r) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1]) + out = str(n) + "x" + str(f) + "x" + str(d) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2]) + one = "1" zero = "0" if (elem_types[0][0] == "f"): @@ -132,6 +158,8 @@ def generate_mlir(config: ConvConfig): conv_template = CONV if "q" in operation: conv_template = CONV_Q + if config.is_3D_conv: + conv_template = CONV_3D operation = conv_template.format( INPUT_TYPE=lhs, FILTER_TYPE=rhs, diff --git a/convbench/problems.py b/convbench/problems.py index a61272a..70dac1b 100644 --- a/convbench/problems.py +++ b/convbench/problems.py @@ -4,49 +4,49 @@ def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]: configs = [] for B in [1, 2, 4, 8]: - configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) return configs def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]: configs = [] for B in [1, 2, 4, 8]: - configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, op, input_dtype, output_dtype)) - configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype)) return configs def get_conv_configs() -> list[tuple[str, ConvConfig]]: