diff --git a/backends/mediatek/partitioner.py b/backends/mediatek/partitioner.py
index b9e342e937d..f15c058507a 100644
--- a/backends/mediatek/partitioner.py
+++ b/backends/mediatek/partitioner.py
@@ -81,6 +81,7 @@ def ops_to_not_decompose(
             torch.ops.aten.upsample_bilinear2d.vec,
             torch.ops.aten.upsample_nearest2d.default,
             torch.ops.aten.upsample_nearest2d.vec,
+            torch.ops.aten._safe_softmax.default,
         ]
         return (ops_not_decompose, None)
diff --git a/backends/mediatek/scripts/mtk_build.sh b/backends/mediatek/scripts/mtk_build.sh
index 6c935b3c80c..97b766807ff 100755
--- a/backends/mediatek/scripts/mtk_build.sh
+++ b/backends/mediatek/scripts/mtk_build.sh
@@ -33,6 +33,7 @@ rm -rf cmake-android-out && mkdir cmake-android-out && cd cmake-android-out
 cmake -DBUCK2="$BUCK_PATH" \
       -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
       -DANDROID_ABI=arm64-v8a \
+      -DANDROID_PLATFORM=android-26 \
       -DEXECUTORCH_BUILD_NEURON=ON \
       -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \
       ..
diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py
index 25362788e31..f788735fa95 100755
--- a/examples/mediatek/aot_utils/oss_utils/utils.py
+++ b/examples/mediatek/aot_utils/oss_utils/utils.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Optional, Set
 
 import torch
 from executorch import exir
@@ -23,6 +23,8 @@ def build_executorch_binary(
     file_name,
     dataset,
     quant_dtype: Optional[Precision] = None,
+    skip_op_name: Optional[Set[str]] = None,
+    skip_op_type: Optional[Set[str]] = None,
 ):
     if quant_dtype is not None:
         quantizer = NeuropilotQuantizer()
@@ -46,14 +48,12 @@ def build_executorch_binary(
     from executorch.exir.program._program import to_edge_transform_and_lower
 
     edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False)
-    # skipped op names are used for deeplabV3 model
     neuro_partitioner = NeuropilotPartitioner(
         [],
         op_types_to_skip=skip_op_type,
         op_names_to_skip=skip_op_name,
     )
+
     edge_prog = to_edge_transform_and_lower(
         aten_dialect,
         compile_config=edge_compile_config,
diff --git a/examples/mediatek/model_export_scripts/dcgan.py b/examples/mediatek/model_export_scripts/dcgan.py
new file mode 100755
index 00000000000..411cbd4abae
--- /dev/null
+++ b/examples/mediatek/model_export_scripts/dcgan.py
@@ -0,0 +1,102 @@
+# Copyright (c) MediaTek Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
+import json
+import os
+import numpy as np
+import argparse
+
+import torch
+import dcgan_main
+from executorch.backends.mediatek import Precision
+from aot_utils.oss_utils.utils import (
+    build_executorch_binary,
+    make_output_dir,
+)
+
+
+class NhwcWrappedModel(torch.nn.Module):
+    def __init__(self, is_gen=True):
+        super(NhwcWrappedModel, self).__init__()
+        if is_gen:
+            self.dcgan = dcgan_main.Generator()
+        else:
+            self.dcgan = dcgan_main.Discriminator()
+
+    def forward(self, input1):
+        nchw_input1 = input1.permute(0, 3, 1, 2)
+        output = self.dcgan(nchw_input1)
+        return output
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing generated artifacts by this example. "
" + "Default ./dcgan", + default="./dcgan", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + # prepare dummy data + inputG = torch.randn(1, 1, 1, 100) + inputD = torch.randn(1, 64, 64, 3) + + # build Generator + netG_instance = NhwcWrappedModel(True) + netG_pte_filename = "dcgan_netG_mtk" + build_executorch_binary( + netG_instance.eval(), + (torch.randn(1, 1, 1, 100),), + f"{args.artifact}/{netG_pte_filename}", + [(inputG,)], + quant_dtype=Precision.A8W8, + ) + + # build Discriminator + netD_instance = NhwcWrappedModel(False) + netD_pte_filename = "dcgan_netD_mtk" + build_executorch_binary( + netD_instance.eval(), + (torch.randn(1, 64, 64, 3),), + f"{args.artifact}/{netD_pte_filename}", + [(inputD,)], + quant_dtype=Precision.A8W8, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list_G.txt" + with open(input_list_file, "w") as f: + f.write("inputG_0_0.bin") + f.flush() + file_name = f"{args.artifact}/inputG_0_0.bin" + inputG.detach().numpy().tofile(file_name) + file_name = f"{args.artifact}/goldenG_0_0.bin" + goldenG = netG_instance(inputG) + goldenG.detach().numpy().tofile(file_name) + + input_list_file = f"{args.artifact}/input_list_D.txt" + with open(input_list_file, "w") as f: + f.write("inputD_0_0.bin") + f.flush() + file_name = f"{args.artifact}/inputD_0_0.bin" + inputD.detach().numpy().tofile(file_name) + file_name = f"{args.artifact}/goldenD_0_0.bin" + goldenD = netD_instance(inputD) + goldenD.detach().numpy().tofile(file_name) + diff --git a/examples/mediatek/model_export_scripts/dcgan_main.py b/examples/mediatek/model_export_scripts/dcgan_main.py new file mode 100755 index 00000000000..77976c5c63b --- /dev/null +++ b/examples/mediatek/model_export_scripts/dcgan_main.py @@ -0,0 +1,70 @@ +"""Ref https://github.com/pytorch/examples/blob/main/dcgan/main.py""" + +import torch.nn as nn + + +class Generator(nn.Module): + def __init__(self): + super().__init__() + self.main = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False), + nn.BatchNorm2d(64 * 8), + nn.ReLU(True), + # state size. (64*8) x 4 x 4 + nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(64 * 4), + nn.ReLU(True), + # state size. (64*4) x 8 x 8 + nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(64 * 2), + nn.ReLU(True), + # state size. (64*2) x 16 x 16 + nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(True), + # state size. (64) x 32 x 32 + nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), + nn.Tanh() + # state size. (3) x 64 x 64 + ) + + def forward(self, input): + output = self.main(input) + return output + +# main_netG_input_shape = [1, 100, 1, 1] +# model = Generator() + + +class Discriminator(nn.Module): + def __init__(self): + super().__init__() + self.main = nn.Sequential( + # input is (3) x 64 x 64 + nn.Conv2d(3, 64, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + # state size. (64) x 32 x 32 + nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(64 * 2), + nn.LeakyReLU(0.2, inplace=True), + # state size. (64*2) x 16 x 16 + nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(64 * 4), + nn.LeakyReLU(0.2, inplace=True), + # state size. (64*4) x 8 x 8 + nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False), + nn.BatchNorm2d(64 * 8), + nn.LeakyReLU(0.2, inplace=True), + # state size. 
+            # state size. (64*8) x 4 x 4
+            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, input):
+        output = self.main(input)
+
+        return output.view(-1, 1).squeeze(1)
+
+# main_netD_input_shape = [1, 3, 64, 64]
+# model = Discriminator()
diff --git a/examples/mediatek/model_export_scripts/deeplab_v3.py b/examples/mediatek/model_export_scripts/deeplab_v3.py
index da6766c0f54..f8113f006ef 100755
--- a/examples/mediatek/model_export_scripts/deeplab_v3.py
+++ b/examples/mediatek/model_export_scripts/deeplab_v3.py
@@ -3,16 +3,19 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import argparse
-import os
 import random
 
 import numpy as np
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model
@@ -26,7 +29,7 @@ def __init__(self):
     def forward(self, input1):
         nchw_input1 = input1.permute(0, 3, 1, 2)
         nchw_output = self.deeplabv3(nchw_input1)
-        return nchw_output.permute(0, 2, 3, 1)
+        return nchw_output
 
 
 def get_dataset(data_size, dataset_dir, download):
@@ -121,4 +124,8 @@ def get_dataset(data_size, dataset_dir, download):
         f"{args.artifact}/{pte_filename}",
         inputs,
         quant_dtype=Precision.A8W8,
+        skip_op_name={
+            "aten_convolution_default_106",
+            "aten_convolution_default_107",
+        },
     )
diff --git a/examples/mediatek/model_export_scripts/edsr.py b/examples/mediatek/model_export_scripts/edsr.py
index 4192d67e569..6ed00364845 100755
--- a/examples/mediatek/model_export_scripts/edsr.py
+++ b/examples/mediatek/model_export_scripts/edsr.py
@@ -6,12 +6,15 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import numpy as np
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.edsr import EdsrModel
diff --git a/examples/mediatek/model_export_scripts/emformer_rnnt.py b/examples/mediatek/model_export_scripts/emformer_rnnt.py
new file mode 100755
index 00000000000..aae34a257a3
--- /dev/null
+++ b/examples/mediatek/model_export_scripts/emformer_rnnt.py
@@ -0,0 +1,162 @@
+# Copyright (c) MediaTek Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
+import json
+import os
+import numpy as np
+import argparse
+
+import torch
+from executorch.examples.models.emformer_rnnt import (
+    EmformerRnntTranscriberModel,
+    EmformerRnntPredictorModel,
+    EmformerRnntJoinerModel,
+)
+from executorch.backends.mediatek import Precision
+from aot_utils.oss_utils.utils import build_executorch_binary
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing generated artifacts by this example. "
+        "Default ./emformer_rnnt",
+        default="./emformer_rnnt",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    # ensure the working directory exists.
+    os.makedirs(args.artifact, exist_ok=True)
+
+    # build Transcriber
+    print("Build Transcriber")
+    transcriber = EmformerRnntTranscriberModel()
+    t_model = transcriber.get_eager_model()
+    inputs = transcriber.get_example_inputs()
+    pte_filename = "emformer_rnnt_t_mtk"
+    build_executorch_binary(
+        t_model.eval(),
+        inputs,
+        f"{args.artifact}/{pte_filename}",
+        [inputs],
+        quant_dtype=Precision.A8W8,
+        skip_op_type={
+            "aten.where.self",
+        },
+        skip_op_name={
+            "aten_div_tensor_mode",
+            "aten_unsqueeze_copy_default",
+            "aten_unsqueeze_copy_default_1",
+            "aten_unsqueeze_copy_default_2",
+            "aten_unsqueeze_copy_default_3",
+            "aten_unsqueeze_copy_default_4",
+            "aten_unsqueeze_copy_default_5",
+            "aten_unsqueeze_copy_default_6",
+            "aten_unsqueeze_copy_default_7",
+            "aten_unsqueeze_copy_default_8",
+            "aten_unsqueeze_copy_default_9",
+            "aten_unsqueeze_copy_default_10",
+            "aten_unsqueeze_copy_default_11",
+            "aten_unsqueeze_copy_default_12",
+            "aten_unsqueeze_copy_default_13",
+            "aten_unsqueeze_copy_default_14",
+            "aten_unsqueeze_copy_default_15",
+            "aten_unsqueeze_copy_default_16",
+            "aten_unsqueeze_copy_default_17",
+            "aten_unsqueeze_copy_default_18",
+            "aten_unsqueeze_copy_default_19",
+        },
+    )
+
+    # save data for inference on device
+    input_list_file = f"{args.artifact}/input_list_t.txt"
+    with open(input_list_file, "w") as f:
+        f.write("input_t_0_0.bin input_t_0_1.bin")
+        f.flush()
+    for idx, data in enumerate(inputs[0]):
+        file_name = f"{args.artifact}/input_t_0_{idx}.bin"
+        data.detach().numpy().tofile(file_name)
+    golden = t_model(inputs[0])
+    for idx, data in enumerate(golden):
+        file_name = f"{args.artifact}/golden_t_0_{idx}.bin"
+        data.detach().numpy().tofile(file_name)
+
+    # build Predictor
+    print("Build Predictor")
+    predictor = EmformerRnntPredictorModel()
+    p_model = predictor.get_eager_model()
+    inputs = predictor.get_example_inputs()
+    pte_filename = "emformer_rnnt_p_mtk"
+    build_executorch_binary(
+        p_model.eval(),
+        inputs,
+        f"{args.artifact}/{pte_filename}",
+        [inputs],
+        quant_dtype=Precision.A8W8,
+        skip_op_name={
+            "aten_permute_copy_default",
+            "aten_embedding_default",
+        },
+    )
+
+    # save data for inference on device
+    input_list_file = f"{args.artifact}/input_list_p.txt"
+    with open(input_list_file, "w") as f:
+        f.write("input_p_0_0.bin input_p_0_1.bin input_p_0_2.bin")
+        f.flush()
+    for idx, data in enumerate(inputs[0]):
+        file_name = f"{args.artifact}/input_p_0_{idx}.bin"
+        try:
+            data.detach().numpy().tofile(file_name)
+        except AttributeError:  # entry may be None rather than a tensor
+            pass
+    golden = p_model(inputs[0])
+    for idx, data in enumerate(golden):
+        file_name = f"{args.artifact}/golden_p_0_{idx}.bin"
+        try:
+            data.detach().numpy().tofile(file_name)
+        except AttributeError:  # entry may be None rather than a tensor
+            pass
+
+    # build Joiner
+    print("Build Joiner")
+    joiner = EmformerRnntJoinerModel()
+    j_model = joiner.get_eager_model()
+    inputs = joiner.get_example_inputs()
+    pte_filename = "emformer_rnnt_j_mtk"
+    build_executorch_binary(
+        j_model.eval(),
+        inputs,
+        f"{args.artifact}/{pte_filename}",
+        [inputs],
+        quant_dtype=Precision.A8W8,
+        skip_op_name={
+            "aten_add_tensor",
+        },
+    )
+
+    # save data for inference on device
+    input_list_file = f"{args.artifact}/input_list_j.txt"
+    with open(input_list_file, "w") as f:
+        f.write("input_j_0_0.bin input_j_0_1.bin input_j_0_2.bin input_j_0_3.bin")
+        f.flush()
+    for idx, data in enumerate(inputs[0]):
+        file_name = f"{args.artifact}/input_j_0_{idx}.bin"
+        data.detach().numpy().tofile(file_name)
+    golden = j_model(inputs[0])
+    for idx, data in enumerate(golden):
+        file_name = f"{args.artifact}/golden_j_0_{idx}.bin"
+        data.detach().numpy().tofile(file_name)
+
diff --git a/examples/mediatek/model_export_scripts/inception_v3.py b/examples/mediatek/model_export_scripts/inception_v3.py
index c28bd85b402..5bc8e99e09e 100755
--- a/examples/mediatek/model_export_scripts/inception_v3.py
+++ b/examples/mediatek/model_export_scripts/inception_v3.py
@@ -6,10 +6,14 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.inception_v3 import InceptionV3Model
diff --git a/examples/mediatek/model_export_scripts/inception_v4.py b/examples/mediatek/model_export_scripts/inception_v4.py
index ccb2ce16f22..a1ded356e50 100755
--- a/examples/mediatek/model_export_scripts/inception_v4.py
+++ b/examples/mediatek/model_export_scripts/inception_v4.py
@@ -6,10 +6,14 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.inception_v4 import InceptionV4Model
diff --git a/examples/mediatek/model_export_scripts/mobilebert.py b/examples/mediatek/model_export_scripts/mobilebert.py
new file mode 100755
index 00000000000..c7122a0484c
--- /dev/null
+++ b/examples/mediatek/model_export_scripts/mobilebert.py
@@ -0,0 +1,68 @@
+# Copyright (c) MediaTek Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
+import json
+import os
+import numpy as np
+import argparse
+
+import torch
+from executorch.backends.mediatek import Precision
+from executorch.examples.models.mobilebert import MobileBertModelExample
+from aot_utils.oss_utils.utils import (
+    build_executorch_binary,
+    make_output_dir,
+)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing generated artifacts by this example. "
+        "Default ./mobilebert",
+        default="./mobilebert",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    # ensure the working directory exists.
+    os.makedirs(args.artifact, exist_ok=True)
+
+    # build pte
+    pte_filename = "mobilebert_mtk"
+    model = MobileBertModelExample()
+    instance = model.get_eager_model().eval()
+    inputs = model.get_example_inputs()
+    build_executorch_binary(
+        instance.eval(),
+        inputs,
+        f"{args.artifact}/{pte_filename}",
+        [inputs],
+        quant_dtype=Precision.A8W8,
+        skip_op_name={
+            "aten_embedding_default",
+        },
+    )
+
+    # save data for inference on device
+    golden = instance(inputs[0])
+    input_list_file = f"{args.artifact}/input_list.txt"
+    with open(input_list_file, "w") as f:
+        f.write("input_0_0.bin")
+        f.flush()
+    file_name = f"{args.artifact}/input_0_0.bin"
+    inputs[0].detach().numpy().tofile(file_name)
+    for idx, data in enumerate(golden):
+        file_name = f"{args.artifact}/golden_0_{idx}.bin"
+        data.detach().numpy().tofile(file_name)
diff --git a/examples/mediatek/model_export_scripts/mobilenet_v2.py b/examples/mediatek/model_export_scripts/mobilenet_v2.py
index 97f2ed884eb..ff754814a45 100755
--- a/examples/mediatek/model_export_scripts/mobilenet_v2.py
+++ b/examples/mediatek/model_export_scripts/mobilenet_v2.py
@@ -6,10 +6,14 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.mobilenet_v2 import MV2Model
diff --git a/examples/mediatek/model_export_scripts/mobilenet_v3.py b/examples/mediatek/model_export_scripts/mobilenet_v3.py
index fed2497ca26..aeb75e7526d 100755
--- a/examples/mediatek/model_export_scripts/mobilenet_v3.py
+++ b/examples/mediatek/model_export_scripts/mobilenet_v3.py
@@ -6,10 +6,14 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.mobilenet_v3 import MV3Model
diff --git a/examples/mediatek/model_export_scripts/resnet18.py b/examples/mediatek/model_export_scripts/resnet18.py
index 2f3af57e7f3..ba2d103efef 100755
--- a/examples/mediatek/model_export_scripts/resnet18.py
+++ b/examples/mediatek/model_export_scripts/resnet18.py
@@ -6,10 +6,14 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.resnet import ResNet18Model
diff --git a/examples/mediatek/model_export_scripts/resnet50.py b/examples/mediatek/model_export_scripts/resnet50.py
index ce23842447b..ad43cb5d943 100755
--- a/examples/mediatek/model_export_scripts/resnet50.py
+++ b/examples/mediatek/model_export_scripts/resnet50.py
@@ -6,10 +6,14 @@
 
 import argparse
 import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.resnet import ResNet50Model
diff --git a/examples/mediatek/model_export_scripts/vit_b_16.py b/examples/mediatek/model_export_scripts/vit_b_16.py
new file mode 100755
index 00000000000..be43a938333
--- /dev/null
+++ b/examples/mediatek/model_export_scripts/vit_b_16.py
@@ -0,0 +1,95 @@
+# Copyright (c) MediaTek Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
+import json
+import numpy as np
+import argparse
+
+import torch
+from executorch.backends.mediatek import Precision
+from executorch.examples.models.torchvision_vit import TorchVisionViTModel
+from aot_utils.oss_utils.utils import (
+    build_executorch_binary,
+    make_output_dir,
+)
+
+
+class NhwcWrappedModel(torch.nn.Module):
+    def __init__(self):
+        super(NhwcWrappedModel, self).__init__()
+        self.vit_b_16 = TorchVisionViTModel().get_eager_model()
+
+    def forward(self, input1):
+        nchw_input1 = input1.permute(0, 3, 1, 2)
+        output = self.vit_b_16(nchw_input1)
+        return output
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing generated artifacts by this example. "
+        "Default ./vit_b_16",
+        default="./vit_b_16",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    # ensure the working directory exists.
+    os.makedirs(args.artifact, exist_ok=True)
+
+    # build pte
+    pte_filename = "vit_b_16_mtk"
+    instance = NhwcWrappedModel()
+
+    # if dropout p == 0, raise it to 1e-6 to avoid -inf values when quantizing
+    for name, module in instance.named_modules():
+        if isinstance(module, torch.nn.Dropout):
+            if module.p == 0:
+                module.p = 1e-6
+
+    inputs = (torch.randn(1, 224, 224, 3),)
+    build_executorch_binary(
+        instance.eval(),
+        inputs,
+        f"{args.artifact}/{pte_filename}",
+        [inputs],
+        quant_dtype=Precision.A8W8,
+        skip_op_name={
+            "aten_permute_copy_default_4",
+            "aten_permute_copy_default_18",
+            "aten_permute_copy_default_32",
+            "aten_permute_copy_default_46",
+            "aten_permute_copy_default_60",
+            "aten_permute_copy_default_74",
+            "aten_permute_copy_default_88",
+            "aten_permute_copy_default_102",
+            "aten_permute_copy_default_116",
+            "aten_permute_copy_default_130",
+            "aten_permute_copy_default_144",
+            "aten_permute_copy_default_158",
+        },
+    )
+
+    # save data for inference on device
+    input_list_file = f"{args.artifact}/input_list.txt"
+    with open(input_list_file, "w") as f:
+        f.write("input_0_0.bin")
+        f.flush()
+    file_name = f"{args.artifact}/input_0_0.bin"
+    inputs[0].detach().numpy().tofile(file_name)
+    file_name = f"{args.artifact}/golden_0_0.bin"
+    golden = instance(inputs[0])
+    golden.detach().numpy().tofile(file_name)
+
diff --git a/examples/mediatek/model_export_scripts/wav2letter.py b/examples/mediatek/model_export_scripts/wav2letter.py
new file mode 100755
index 00000000000..f999ec0f55b
--- /dev/null
+++ b/examples/mediatek/model_export_scripts/wav2letter.py
@@ -0,0 +1,71 @@
+# Copyright (c) MediaTek Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import sys
+
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
+import json
+import numpy as np
+import argparse
+
+import torch
+from executorch.backends.mediatek import Precision
+from executorch.examples.models.wav2letter import Wav2LetterModel
+from aot_utils.oss_utils.utils import (
+    build_executorch_binary,
+    make_output_dir,
+)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing generated artifacts by this example. "
+        "Default ./wav2letter",
+        default="./wav2letter",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    # ensure the working directory exists.
+    os.makedirs(args.artifact, exist_ok=True)
+
+    # build pte
+    pte_filename = "wav2letter_mtk"
+    model = Wav2LetterModel()
+    instance = model.get_eager_model()
+    inputs = model.get_example_inputs()
+
+    build_executorch_binary(
+        instance.eval(),
+        inputs,
+        f"{args.artifact}/{pte_filename}",
+        [inputs],
+        quant_dtype=Precision.A8W8,
+        skip_op_name={
+            "aten_convolution_default",
+            "aten_convolution_default_1",
+            "aten_convolution_default_9",
+            "aten__log_softmax_default",
+        },
+    )
+
+    # save data for inference on device
+    input_list_file = f"{args.artifact}/input_list.txt"
+    with open(input_list_file, "w") as f:
+        f.write("input_0_0.bin")
+        f.flush()
+    file_name = f"{args.artifact}/input_0_0.bin"
+    inputs[0].detach().numpy().tofile(file_name)
+    file_name = f"{args.artifact}/golden_0_0.bin"
+    golden = instance(inputs[0])
+    golden.detach().numpy().tofile(file_name)
+
diff --git a/examples/mediatek/mtk_build_examples.sh b/examples/mediatek/mtk_build_examples.sh
index df70489cf2a..87787e373d3 100755
--- a/examples/mediatek/mtk_build_examples.sh
+++ b/examples/mediatek/mtk_build_examples.sh
@@ -39,6 +39,7 @@ main() {
     -DBUCK2="$BUCK_PATH" \
     -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI=arm64-v8a \
+    -DANDROID_PLATFORM=android-26 \
     -DANDROID_NATIVE_API_LEVEL=23 \
     -DEXECUTORCH_BUILD_NEURON=ON \
     -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \
@@ -58,6 +59,7 @@ main() {
   cmake -DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \
     -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI=arm64-v8a \
+    -DANDROID_PLATFORM=android-26 \
    -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
     -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \
     -B"${example_build_dir}" \
diff --git a/examples/mediatek/shell_scripts/export_oss.sh b/examples/mediatek/shell_scripts/export_oss.sh
index 3da5dc41f94..3e00506c05a 100755
--- a/examples/mediatek/shell_scripts/export_oss.sh
+++ b/examples/mediatek/shell_scripts/export_oss.sh
@@ -26,4 +26,19 @@ then
 elif [ $model = "resnet50" ]
 then
     python3 model_export_scripts/resnet50.py -d PATH_TO_DATASET
+elif [ $model = "dcgan" ]
+then
+    python3 model_export_scripts/dcgan.py
+elif [ $model = "wav2letter" ]
+then
+    python3 model_export_scripts/wav2letter.py
+elif [ $model = "vit_b_16" ]
+then
+    python3 model_export_scripts/vit_b_16.py
+elif [ $model = "mobilebert" ]
+then
+    python3 model_export_scripts/mobilebert.py
+elif [ $model = "emformer_rnnt" ]
+then
+    python3 model_export_scripts/emformer_rnnt.py
 fi
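
Usage sketch (illustrative, not part of the patch): after the changes to build_executorch_binary above, callers choose which ops stay on CPU instead of relying on the previously hard-coded deeplabV3 skip list. MyModel and the output path below are placeholders; the operator strings mirror ones used by the export scripts in this patch.

    import torch
    from executorch.backends.mediatek import Precision
    from aot_utils.oss_utils.utils import build_executorch_binary

    model = MyModel().eval()  # placeholder for any torch.nn.Module
    inputs = (torch.randn(1, 224, 224, 3),)

    build_executorch_binary(
        model,
        inputs,                      # example inputs used to export the graph
        "./artifacts/my_model_mtk",  # output .pte path prefix
        [inputs],                    # calibration dataset for A8W8 quantization
        quant_dtype=Precision.A8W8,
        # keep every instance of an aten op type on CPU ...
        skip_op_type={"aten.where.self"},
        # ... or skip individual nodes by their exported-graph names
        skip_op_name={"aten_convolution_default_106"},
    )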