diff --git a/extension/training/examples/CIFAR/README.md b/extension/training/examples/CIFAR/README.md new file mode 100644 index 00000000000..0b05bb653aa --- /dev/null +++ b/extension/training/examples/CIFAR/README.md @@ -0,0 +1,58 @@ +# CIFAR-10 End-to-End Fine-Tuning Tutorial + +## Objective: + +This tutorial guides users through training a simple PyTorch CNN model on the server and subsequently fine-tuning the model on their edge devices. + +### Key Objectives + +1. **Server-Side Training**: Users can leverage the computational resources of the server to perform initial model training using PyTorch. +2. **Edge Device Fine-Tuning**: Pre-trained models are lowered and deployed on mobile devices through ExecuTorch, where they undergo fine-tuning. +3. **Performance Benchmarking**: Track comprehensive performance metrics for on-device fine-tuning operations, measuring factors such as training speed, memory usage, and model accuracy to evaluate ExecuTorch's effectiveness in the edge environment. + +## ExecuTorch Environment Setup + +For easier management of Python environments and packages, we recommend using a Python environment management tool such as `conda`, `venv`, or `uv`. In this demonstration, we will use `uv` to set up the Python environment. + +To install ExecuTorch in a [`uv`](https://docs.astral.sh/uv/getting-started/installation/) Python environment, use the following commands: + +```bash +$ git clone https://github.com/pytorch/executorch.git --recurse-submodules +$ cd executorch +$ uv venv --seed --prompt et --python 3.10 +$ source .venv/bin/activate +$ git fetch origin +$ git submodule sync --recursive +$ git submodule update --init --recursive +$ ./install_executorch.sh +``` + +## Data Preparation + +We can download the CIFAR-10 dataset from the [official website](https://www.cs.toronto.edu/~kriz/cifar.html) and extract it to the desired location. 
Alternatively, we can use the following command to download, extract, and create a balanced dataset: + +```bash +python data_utils.py --train-data-batch-path ./data/cifar-10/cifar-10-batches-py/data_batch_1 --train-output-path ./data/cifar-10/extracted_data/train_data.bin --test-data-batch-path ./data/cifar-10/cifar-10-batches-py/test_batch --test-output-path ./data/cifar-10/extracted_data/test_data.bin --train-images-per-class 100 +``` + +## Model Export + +If users already have a pre-trained PyTorch model, they can export a standalone `pte` file using the following command: + +```bash +python export.py --train-model-path cifar10_model.pth --pte-only-model-path cifar10_model.pte +``` + +To get both the `pte` and `ptd` files, they can use the following command: + +```bash +python export.py --train-model-path cifar10_model.pth --with-ptd --pte-model-path cifar10_model.pte --ptd-model-path . +``` + +## Model Training and Fine-Tuning + +To run the end-to-end example, users can use the following command: + +```bash +python main.py --data-dir ./data --model-path cifar10_model.pth --pte-model-path cifar10_model.pte --split-pte-model-path cifar10_model_pte_only.pte --save-pt-json cifar10_pt.json --save-et-json cifar10_et.json --ptd-model-dir . --epochs 1 --fine-tune-epochs 1 +``` diff --git a/extension/training/examples/CIFAR/data_utils.py b/extension/training/examples/CIFAR/data_utils.py new file mode 100644 index 00000000000..e683581ab8a --- /dev/null +++ b/extension/training/examples/CIFAR/data_utils.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import argparse +import os +import pickle +import typing +from collections import defaultdict + +import numpy as np +import torch +import torchvision +from PIL import Image +from torch.utils.data import DataLoader, Dataset, Subset + + +class BalancedCIFARDataset(Dataset): + """ + Custom dataset class to load balanced + CIFAR-10 data from binary file. + """ + + def __init__( + self, + data_path: str, + transform: typing.Optional[torchvision.transforms.Compose] = None, + ) -> None: + """ + Args: + data_path: Path to the balanced dataset binary file + transform: Optional transformation to be applied on a sample + """ + self.data = [] + self.labels = [] + + # Read binary format: 1 byte label + 3072 bytes image data per record + with open(data_path, "rb") as f: + while True: + # Read label (1 byte) + label_byte = f.read(1) + if not label_byte: # End of file + break + label = int.from_bytes(label_byte, byteorder="big") + + # Read image data (3 * 32 * 32 = 3072 bytes) + image_bytes = f.read(3072) + if len(image_bytes) != 3072: + break # Incomplete record + + # Convert bytes to numpy array + image_data = np.frombuffer(image_bytes, dtype=np.uint8) + + self.data.append(image_data) + self.labels.append(label) + + self.data = np.array(self.data) + self.labels = np.array(self.labels) + self.transform = transform + + print(f"Loaded {len(self.data)} images from {data_path}") + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx: int) -> typing.Tuple[Image.Image, int]: + # Reshape from (3072,) to (32, 32, 3) and convert to PIL Image + image_data = self.data[idx].reshape(3, 32, 32).transpose(1, 2, 0) + image = Image.fromarray(image_data) + label = self.labels[idx] + + if self.transform: + image = self.transform(image) + + return image, label + + +def create_balanced_cifar_dataset( + data_batch_path: str = "./data/cifar-10/cifar-10-batches-py/data_batch_1", + output_path: str = "./data/cifar-10/extracted_data/train_data.bin", + 
images_per_class: int = 100, +) -> str: + """ + Reads CIFAR-10 data from data_batch_1 file and creates a balanced dataset + with specified number of images per class, saved in binary format + compatible with Android. + + Args: + data_batch_path: Path to the CIFAR-10 data_batch_1 file + output_path: Path where the balanced dataset will be saved + images_per_class: Number of images to extract per class (default: 100) + """ + # Load the CIFAR-10 data batch + with open(data_batch_path, "rb") as f: + data_dict = pickle.load(f, encoding="bytes") + + # Extract data and labels + data = data_dict[b"data"] # Shape: (10000, 3072) + labels = data_dict[b"labels"] # List of 10000 labels + + # Group images by class + class_images = defaultdict(list) + class_labels = defaultdict(list) + + for i, label in enumerate(labels): + if len(class_images[label]) < images_per_class: + class_images[label].append(data[i]) + class_labels[label].append(label) + + # Combine all selected images and labels + selected_data = [] + selected_labels = [] + + for class_id in range(10): # CIFAR-10 has 10 classes (0-9) + if class_id in class_images: + selected_data.extend(class_images[class_id]) + selected_labels.extend(class_labels[class_id]) + print( + f"Class {class_id}: " f"{len(class_images[class_id])} images selected" + ) + + # Convert to numpy arrays + selected_data = np.array(selected_data, dtype=np.uint8) + selected_labels = np.array(selected_labels, dtype=np.uint8) + + # Ensure the output directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Save in binary format compatible with Android CIFAR-10 reader + # Format: 1 byte label + 3072 bytes image data per record + with open(output_path, "wb") as f: + for i in range(len(selected_data)): + # Write label as single byte + f.write(bytes([selected_labels[i]])) + # Write image data (3072 bytes) + f.write(selected_data[i].tobytes()) + + print(f"Balanced dataset saved to {output_path}") + print(f"Total images: 
{len(selected_data)}") + print(f"File size: {os.path.getsize(output_path)} bytes") + print(f"Expected size: {len(selected_data) * (1 + 3072)} bytes") + return output_path + + +def get_data_loaders( + batch_size: int = 4, + num_workers: int = 2, + data_dir: str = "./data", + use_balanced_dataset: bool = True, + images_per_class: int = 100, +) -> typing.Tuple[DataLoader, DataLoader]: + """ + Create data loaders for training, validation, and testing. + + Args: + batch_size: Batch size for data loaders + num_workers: Number of worker processes for data loading + data_dir: Root directory for data + use_balanced_dataset: Whether to use balanced dataset or + standard CIFAR-10 + images_per_class: Number of images per class for balanced dataset + """ + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ), + ] + ) + + if use_balanced_dataset: + # Download CIFAR-10 first to ensure the raw data exists + print("Downloading CIFAR-10 dataset...") + torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True) + torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True) + + # The actual path where torchvision stores CIFAR-10 data + cifar_data_dir = os.path.join(data_dir, "cifar-10-batches-py") + + # Create balanced dataset if it doesn't exist + balanced_data_path = os.path.join( + data_dir, "cifar-10/extracted_data/train_data.bin" + ) + data_batch_path = os.path.join(cifar_data_dir, "data_batch_1") + + # Ensure the output directory exists + os.makedirs(os.path.dirname(balanced_data_path), exist_ok=True) + + # Create balanced dataset if it doesn't exist + if not os.path.exists(balanced_data_path): + print("Creating balanced train dataset...") + create_balanced_cifar_dataset( + data_batch_path=data_batch_path, + 
output_path=balanced_data_path, + images_per_class=images_per_class, + ) + + # Use balanced dataset for training + trainset = BalancedCIFARDataset(balanced_data_path, transform=transforms) + + indices = torch.randperm(len(trainset)).tolist() + + train_subset = Subset(trainset, indices) + + balanced_test_data_path = os.path.join( + data_dir, "cifar-10/extracted_data/test_data.bin" + ) + test_data_batch_path = os.path.join(cifar_data_dir, "test_batch") + # Ensure the output directory exists + os.makedirs(os.path.dirname(balanced_test_data_path), exist_ok=True) + # Create balanced dataset if it doesn't exist + if not os.path.exists(balanced_test_data_path): + print("Creating balanced test dataset...") + create_balanced_cifar_dataset( + data_batch_path=test_data_batch_path, + output_path=balanced_test_data_path, + images_per_class=images_per_class, + ) + # Use balanced dataset for testing + test_set = BalancedCIFARDataset(balanced_test_data_path, transform=transforms) + + else: + # Use standard CIFAR-10 dataset + trainset = torchvision.datasets.CIFAR10( + root=data_dir, train=True, download=True, transform=transforms + ) + + train_set_indices = torch.randperm(len(trainset)).tolist() + + train_subset = Subset(trainset, train_set_indices) + + # Test set always uses standard CIFAR-10 + test_set = torchvision.datasets.CIFAR10( + root=data_dir, train=False, download=True, transform=transforms + ) + + train_loader = DataLoader( + train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + + test_loader = DataLoader( + test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + return train_loader, test_loader + + +def count_images_per_class(loader: DataLoader) -> typing.Dict[int, int]: + """ + Count the number of images per class in a DataLoader. + + This function iterates through a DataLoader and counts how many images + belong to each class based on their labels. 
+ + Args: + loader (DataLoader): The DataLoader containing image-label pairs + + Returns: + Dict[int, int]: A dictionary mapping class IDs to their counts + """ + class_counts = defaultdict(int) + for _, labels in loader: + for label in labels: + class_counts[label.item()] += 1 + return class_counts + + +def parse_args() -> argparse.Namespace: + """ + Parse command line arguments for the CIFAR-10 training script. + + This function sets up an argument parser with various configuration options + for training a CIFAR-10 model with ExecutorTorch, including data paths, + training hyperparameters, and model save locations. + + Returns: + argparse.Namespace: An object containing all the parsed command line + arguments with their respective values (either user-provided or + defaults). + + """ + parser = argparse.ArgumentParser(description="CIFAR-10 Data Preparation Example") + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size for data loaders (default: 4)", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="Number of worker processes for data loading (default: 2)", + ) + + parser.add_argument( + "--data-dir", + type=str, + default="./data", + help="Directory to download CIFAR-10 dataset (default: ./data)", + ) + + parser.add_argument( + "--use-balanced-dataset", + action="store_true", + default=True, + help="Use balanced dataset instead of full CIFAR-10 (default: True)", + ) + + parser.add_argument( + "--train-data-batch-path", + type=str, + default="./data/cifar-10/cifar-10-batches-py/data_batch_1", + help="Directory for cifar-10-batches-py", + ) + + parser.add_argument( + "--train-output-path", + type=str, + default="./data/cifar-10/extracted_data/train_data.bin", + help="Directory for saving the train_data.bin", + ) + + parser.add_argument( + "--test-data-batch-path", + type=str, + default="./data/cifar-10/cifar-10-batches-py/test_batch_1", + help="Directory for cifar-10-batches-py", + ) + + 
parser.add_argument( + "--test-output-path", + type=str, + default="./data/cifar-10/extracted_data/train_data.bin", + help="Directory for saving the train_data.bin", + ) + + parser.add_argument( + "--train-images-per-class", + type=int, + default=100, + help="Number of images per class for balanced dataset (default: 100 and max: 1000)", + ) + + return parser.parse_args() + + +def main() -> None: + """ + Utility function to demonstrate data loading and class distribution analysis. + + This function creates data loaders for CIFAR-10 dataset using the get_data_loaders + function, then counts and prints the number of images per class in both the + training and test datasets to verify balanced distribution. + + Returns: + None + """ + + args = parse_args() + + # Create data loaders + train_loader, test_loader = get_data_loaders( + batch_size=args.batch_size, + data_dir=args.data_dir, + use_balanced_dataset=args.use_balanced_dataset, + images_per_class=args.train_images_per_class, + ) + + # Count images per class + class_counts = count_images_per_class(train_loader) + + print("Class counts in train dataset:", class_counts) + + class_counts = count_images_per_class(test_loader) + + print("Class counts in test dataset:", class_counts) + + +if __name__ == "__main__": + main() diff --git a/extension/training/examples/CIFAR/export.py b/extension/training/examples/CIFAR/export.py new file mode 100644 index 00000000000..ea388019864 --- /dev/null +++ b/extension/training/examples/CIFAR/export.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import argparse + +import torch +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from executorch.extension.training.examples.CIFAR.data_utils import get_data_loaders +from executorch.extension.training.examples.CIFAR.model import ( + CIFAR10Model, + ModelWithLoss, +) +from torch.export import export +from torch.export.experimental import _export_forward_backward + + +def export_model_combined( + net: torch.nn.Module, + input_tensor: torch.Tensor, + label_tensor: torch.Tensor, + with_external_tensor_data: bool = False, +) -> ExecuTorchModule: + """ + Export a PyTorch model to an ExecutorTorch module format, optionally with external tensor data. + + This function takes a PyTorch model and sample input/label tensors, + wraps the model with a loss function, exports it using torch.export, + applies forward-backward pass optimization, converts it to edge format, + and finally to ExecutorTorch format. If with_external_tensor_data is True, + the model will be exported with external constants and mutable weights. + + TODO: set dynamic shape for the batch size here. + + Args: + net (torch.nn.Module): The PyTorch model to be exported + input_tensor (torch.Tensor): A sample input tensor with the correct shape + label_tensor (torch.Tensor): A sample label tensor with the correct shape + with_external_tensor_data (bool, optional): Whether to export with external tensor data. + Defaults to False. 
+ + Returns: + ExecuTorchModule: The exported model in ExecutorTorch format ready for deployment + """ + criterion = torch.nn.CrossEntropyLoss() + model_with_loss = ModelWithLoss(net, criterion) + ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) + ep = _export_forward_backward(ep) + ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) + + if with_external_tensor_data: + ep = ep.to_executorch( + config=ExecutorchBackendConfig( + external_constants=True, # This is the flag that + # enables the external constants to be stored in a + # separate file external to the PTE file. + external_mutable_weights=True, # This is the flag + # that enables all trainable weights will be stored + # in a separate file external to the PTE file. + ) + ) + else: + ep = ep.to_executorch() + + return ep + + +def get_pte_only(net: torch.nn.Module) -> ExecuTorchModule: + """ + Generate an ExecutorTorch module from a PyTorch model without external tensor data. + + This function retrieves a sample input and label tensor from the test data loader, + and uses them to export the given PyTorch model to an ExecutorTorch module format + without external constants or mutable weights. + + Args: + net (torch.nn.Module): The PyTorch model to be exported. + + Returns: + ExecuTorchModule: The exported model in ExecutorTorch format. + """ + _, test_loader = get_data_loaders() + # get a sample input and label tensor + validation_sample_data = next(iter(test_loader)) + sample_input, sample_label = validation_sample_data + return export_model_combined( + net, sample_input, sample_label, with_external_tensor_data=False + ) + + +def get_pte_with_ptd(net: torch.nn.Module) -> ExecuTorchModule: + """ + Generate an ExecutorTorch module from a PyTorch model with external tensor data. 
+ + This function retrieves a sample input and label tensor from the test data loader, + and uses them to export the given PyTorch model to an ExecutorTorch module format + with external constants and mutable weights. + + Args: + net (torch.nn.Module): The PyTorch model to be exported. + + Returns: + ExecuTorchModule: The exported model in ExecutorTorch format with external tensor data. + """ + _, test_loader = get_data_loaders() + # get a sample input and label tensor + validation_sample_data = next(iter(test_loader)) + sample_input, sample_label = validation_sample_data + return export_model_combined( + net, sample_input, sample_label, with_external_tensor_data=True + ) + + +def export_model( + net: torch.nn.Module, + with_ptd: bool = False, +) -> ExecuTorchModule: + """ + Export a PyTorch model to ExecutorTorch format, optionally with external tensor data. + + This function is a high-level wrapper that handles getting sample data and + calling the appropriate export function based on the with_ptd flag. + + Args: + net (torch.nn.Module): The PyTorch model to be exported + with_ptd (bool, optional): Whether to export with external tensor data. + Defaults to False. + + Returns: + ExecuTorchModule: The exported model in ExecutorTorch format + """ + _, test_loader = get_data_loaders() + validation_sample_data = next(iter(test_loader)) + sample_input, sample_label = validation_sample_data + + return export_model_combined( + net, sample_input, sample_label, with_external_tensor_data=with_ptd + ) + + +def save_model(ep: ExecuTorchModule, model_path: str) -> None: + """ + Save an ExecutorTorch model to a specified file path. + + This function writes the buffer of an ExecutorTorchModule to a + file in binary format. + + Args: + ep (ExecuTorchModule): The ExecutorTorch module to be saved. + model_path (str): The file path where the model will be saved. 
+ """ + with open(model_path, "wb") as file: + file.write(ep.buffer) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="CIFAR-10 Data Preparation Example") + parser.add_argument( + "--train-model-path", + type=str, + default="./cifar10_model.pth", + help="Path to the saved PyTorch model", + ) + parser.add_argument( + "--pte-only-model-path", + type=str, + default="./cifar10_pte_only_model.pte", + help="Path to the saved PTE only", + ) + parser.add_argument( + "--with-ptd", + action="store_true", + help="Whether to export the model with ptd", + ) + parser.add_argument( + "--pte-model-path", + type=str, + default="./cifar10_model.pte", + help="Path to the saved PTE", + ) + parser.add_argument( + "--ptd-model-path", + type=str, + default="./cifar10_model.ptd", + help="Path to the saved PTD", + ) + + return parser.parse_args() + + +def update_tensor_data_and_save(exported_program, ptd_model_path, pte_model_path): + exported_program._tensor_data["generic_cifar"] = exported_program._tensor_data.pop( + "_default_external_constant" + ) + exported_program.write_tensor_data_to_file(ptd_model_path) + save_model(exported_program, pte_model_path) + + +def main(): + args = parse_args() + net = CIFAR10Model() + state_dict = torch.load(args.train_model_path, weights_only=True) + net.load_state_dict(state_dict) + if args.with_ptd: + exported_program = get_pte_with_ptd(net) + update_tensor_data_and_save( + exported_program, args.ptd_model_path, args.pte_model_path + ) + else: + exported_program = get_pte_only(net) + save_model(exported_program, args.pte_only_model_path) + + +if __name__ == "__main__": + main() diff --git a/extension/training/examples/CIFAR/main.py b/extension/training/examples/CIFAR/main.py index 396dcf07593..c039cfa4ae8 100644 --- a/extension/training/examples/CIFAR/main.py +++ b/extension/training/examples/CIFAR/main.py @@ -8,110 +8,19 @@ import argparse -import torch -from executorch.exir import EdgeCompileConfig, 
ExecutorchBackendConfig, to_edge -from executorch.extension.pybindings.portable_lib import ExecuTorchModule -from executorch.extension.training.examples.CIFAR.model import ( - CIFAR10Model, - ModelWithLoss, +from executorch.extension.training.examples.CIFAR.data_utils import get_data_loaders +from executorch.extension.training.examples.CIFAR.export import ( + get_pte_only, + get_pte_with_ptd, + save_model, + update_tensor_data_and_save, ) -from executorch.extension.training.examples.CIFAR.utils import ( - fine_tune_executorch_model, - get_data_loaders, +from executorch.extension.training.examples.CIFAR.model import CIFAR10Model +from executorch.extension.training.examples.CIFAR.train_utils import ( save_json, + train_both_models, train_model, ) -from torch.export import export -from torch.export.experimental import _export_forward_backward - - -def export_model( - net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor -) -> ExecuTorchModule: - """ - Export a PyTorch model to an ExecutorTorch module format. - - This function takes a PyTorch model and sample input/label - tensors, wraps the model with a loss function, exports it - using torch.export, applies forward-backward pass - optimization, converts it to edge format, and finally to - ExecutorTorch format. 
- - Args: - net (torch.nn.Module): The PyTorch model to be exported - input_tensor (torch.Tensor): A sample input tensor with - the correct shape - label_tensor (torch.Tensor): A sample label tensor with - the correct shape - - Returns: - ExecuTorchModule: The exported model in ExecutorTorch - format ready for deployment - """ - criterion = torch.nn.CrossEntropyLoss() - model_with_loss = ModelWithLoss(net, criterion) - ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) - ep = _export_forward_backward(ep) - ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) - ep = ep.to_executorch() - return ep - - -def export_model_with_ptd( - net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor -) -> ExecuTorchModule: - """ - Export a PyTorch model to an ExecutorTorch module format with external - tensor data. - - This function takes a PyTorch model and sample input/label tensors, - wraps the model with a loss function, exports it using torch.export, - applies forward-backward pass optimization, converts it to edge format, - and finally to ExecutorTorch format with external constants and mutable - weights. - - Args: - net (torch.nn.Module): The PyTorch model to be exported - input_tensor (torch.Tensor): A sample input tensor with the correct - shape - label_tensor (torch.Tensor): A sample label tensor with the correct - shape - - Returns: - ExecuTorchModule: The exported model in ExecutorTorch format ready for - deployment - """ - criterion = torch.nn.CrossEntropyLoss() - model_with_loss = ModelWithLoss(net, criterion) - ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) - ep = _export_forward_backward(ep) - ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) - ep = ep.to_executorch( - config=ExecutorchBackendConfig( - external_constants=True, # This is the flag that - # enables the external constants to be stored in a - # separate file external to the PTE file. 
- external_mutable_weights=True, # This is the flag - # that enables all trainable weights will be stored - # in a separate file external to the PTE file. - ) - ) - return ep - - -def save_model(ep: ExecuTorchModule, model_path: str) -> None: - """ - Save an ExecutorTorch model to a specified file path. - - This function writes the buffer of an ExecutorTorchModule to a - file in binary format. - - Args: - ep (ExecuTorchModule): The ExecutorTorch module to be saved. - model_path (str): The file path where the model will be saved. - """ - with open(model_path, "wb") as file: - file.write(ep.buffer) def parse_args() -> argparse.Namespace: @@ -212,6 +121,13 @@ def parse_args() -> argparse.Namespace: help="Learning rate for fine-tuning (default: 0.001)", ) + parser.add_argument( + "--momentum", + type=float, + default=0.9, + help="Momentum for fine-tuning (default: 0.9)", + ) + return parser.parse_args() @@ -241,48 +157,31 @@ def main() -> None: save_json(train_hist, args.save_pt_json) - # Export the model for et runtime - validation_sample_data = next(iter(test_loader)) - img, lbl = validation_sample_data - sample_input = img[0:1, :] - sample_label = lbl[0:1] - - ep = export_model(model, sample_input, sample_label) + ep = get_pte_only(model) save_model(ep, args.pte_model_path) - et_model, et_hist = fine_tune_executorch_model( - args.pte_model_path, - args.pte_model_path, - train_loader, - test_loader, + pytorch_model, et_mod, pytorch_history, et_history = train_both_models( + pytorch_model=model, + et_model_path=args.pte_model_path, + train_loader=train_loader, + test_loader=test_loader, epochs=args.fine_tune_epochs, - learning_rate=args.learning_rate, + lr=args.learning_rate, + momentum=args.momentum, + pytorch_save_path=args.model_path, ) - save_json(et_hist, args.save_et_json) + save_json(et_history, args.save_et_json) + save_json(pytorch_history, args.save_pt_json) # Split the model into the pte and ptd files - exported_program = export_model_with_ptd(model, 
sample_input, sample_label) + exported_program = get_pte_with_ptd(model) - exported_program._tensor_data["generic_cifar"] = exported_program._tensor_data.pop( - "_default_external_constant" + update_tensor_data_and_save( + exported_program, args.ptd_model_dir, args.split_pte_model_path ) - exported_program.write_tensor_data_to_file(args.ptd_model_dir) - save_model(exported_program, args.split_pte_model_path) - - # Finetune the PyTorch model - model, train_hist = train_model( - model, - train_loader, - test_loader, - epochs=args.fine_tune_epochs, - lr=args.learning_rate, - momentum=0.9, - save_path=args.model_path, - ) - - save_json(train_hist, args.save_pt_json) + print("\n\nProcess complete!!!\n\n") if __name__ == "__main__": diff --git a/extension/training/examples/CIFAR/targets.bzl b/extension/training/examples/CIFAR/targets.bzl index 3131f8e496d..786160d65b3 100644 --- a/extension/training/examples/CIFAR/targets.bzl +++ b/extension/training/examples/CIFAR/targets.bzl @@ -17,16 +17,59 @@ def define_common_targets(): ) runtime.python_library( - name = "utils", - srcs = ["utils.py"], + name = "data_utils", + srcs = ["data_utils.py"], + deps = [ + "//caffe2:torch", + "//pytorch/vision:torchvision", + ], + ) + + runtime.python_binary( + name = "data_processing", + srcs = ["data_utils.py"], + main_function = "executorch.extension.training.examples.CIFAR.data_utils.main", + deps = [ + "//caffe2:torch", + "//pytorch/vision:torchvision", + ], + ) + + runtime.python_library( + name = "train_utils", + srcs = ["train_utils.py"], visibility = [], # Private deps = [ "//caffe2:torch", "fbsource//third-party/pypi/tqdm:tqdm", + "//executorch/extension/pybindings:portable_lib", + "//executorch/extension/training:lib", + ], + ) + + runtime.python_binary( + name = "model_export", + srcs = ["export.py"], + main_function = "executorch.extension.training.examples.CIFAR.export.main", + deps = [ + ":model", + ":data_utils", "//caffe2:torch", "//executorch/exir:lib", - 
"//executorch/extension/pybindings:portable_lib", # @manual - "//pytorch/vision:torchvision", + "//executorch/extension/pybindings:portable_lib", + ], + ) + + runtime.python_library( + name = "export", + srcs = ["export.py"], + visibility = [], # Private + deps = [ + ":model", + ":data_utils", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", ], ) @@ -36,7 +79,9 @@ def define_common_targets(): main_function = "executorch.extension.training.examples.CIFAR.main.main", deps = [ ":model", - ":utils", + ":data_utils", + ":export", + ":train_utils", "fbsource//third-party/pypi/tqdm:tqdm", "//caffe2:torch", "//executorch/exir:lib", diff --git a/extension/training/examples/CIFAR/train.cpp b/extension/training/examples/CIFAR/train.cpp index f32f2537c9c..9539fccebd2 100644 --- a/extension/training/examples/CIFAR/train.cpp +++ b/extension/training/examples/CIFAR/train.cpp @@ -61,7 +61,7 @@ DEFINE_string( DEFINE_int32( batch_size, - 1, + 4, "Batch size for training."); // Batch size for training (must match // export batch size) diff --git a/extension/training/examples/CIFAR/train_utils.py b/extension/training/examples/CIFAR/train_utils.py new file mode 100644 index 00000000000..baed740d938 --- /dev/null +++ b/extension/training/examples/CIFAR/train_utils.py @@ -0,0 +1,624 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import json +import os +import time +import typing + +import torch +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, + ExecuTorchModule, +) +from executorch.extension.training import ( + _load_for_executorch_for_training_from_buffer, + get_sgd_optimizer, +) +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def save_json( + history: typing.Dict[int, typing.Dict[str, float]], json_path: str +) -> str: + """ + Save training/validation history to a JSON file. + + This function takes a dictionary containing training/validation metrics + organized by epoch and saves it to a JSON file at the specified path. + + Args: + history (Dict[int, Dict[str, float]]): Dictionary with epoch numbers + as keys and dictionaries of metrics (loss, accuracy, etc.) as + values. + json_path (str): File path where the JSON file will be saved. + + Returns: + str: The path where the JSON file was saved. + """ + with open(json_path, "w") as f: + json.dump(history, f, indent=4) + print(f"History saved to {json_path}") + return json_path + + +def train_model( + model: torch.nn.Module, + train_loader: DataLoader, + test_loader: DataLoader, + epochs: int = 1, + lr: float = 0.001, + momentum: float = 0.9, + save_path: str = "./best_cifar10_model.pth", +) -> typing.Tuple[torch.nn.Module, typing.Dict[int, typing.Dict[str, float]]]: + """ + The train_model function takes a model, a train_loader, and the number of + epochs as input.It then trains the model on the training data for the + specified number of epochs using the SGD optimizer and a cross-entropy loss + function. The function returns the trained model. + + args: + model (Required): The model to be trained. + train_loader (tuple, Required): The training data loader. + test_loader (tuple, Optional): The testing data loader. + epochs (int, optional): The number of epochs to train the model. + lr (float, optional): The learning rate for the SGD optimizer. 
def train_model(
    model: torch.nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int = 1,
    lr: float = 0.001,
    momentum: float = 0.9,
    save_path: str = "./best_cifar10_model.pth",
) -> typing.Tuple[torch.nn.Module, typing.Dict[int, typing.Dict[str, float]]]:
    """
    Train a PyTorch model with SGD and a cross-entropy loss.

    After every epoch the model is evaluated on ``test_loader`` (when one is
    given) and its ``state_dict`` is written to ``save_path`` whenever the
    average testing loss improves.

    Args:
        model (torch.nn.Module): The model to be trained (updated in place).
        train_loader (DataLoader): The training data loader; yields
            ``(inputs, labels)`` batches.
        test_loader (DataLoader, optional): The testing data loader. Pass
            ``None`` to skip evaluation; the final weights are then saved to
            ``save_path`` once training ends and the testing metrics are
            left out of the history.
        epochs (int, optional): The number of epochs to train the model.
        lr (float, optional): The learning rate for the SGD optimizer.
        momentum (float, optional): The momentum for the SGD optimizer.
        save_path (str, optional): Path to save the best model.

    Returns:
        tuple: The trained model and a history dictionary keyed by epoch
        index. Each entry holds ``train_loss``, ``train_accuracy``,
        ``training_time`` and ``train_time_per_image`` and, when a test
        loader was given, ``testing_loss``, ``testing_accuracy``,
        ``testing_time`` and ``test_time_per_image``. All timings cover that
        single epoch only.
    """

    history = {}
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Initialize best testing loss to a high value for checkpointing
    # on the best model
    best_test_loss = float("inf")

    # Create directory for save_path if it doesn't exist
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        # The clock is restarted every epoch so the reported training time
        # does not accumulate earlier epochs and their evaluation passes.
        train_start_time = time.time()
        model.train()
        epoch_loss = 0.0
        epoch_correct = 0
        epoch_total = 0
        train_batches = 0
        for inputs, labels in train_loader:
            # Set the gradients to zero for the next backward pass
            optimizer.zero_grad()

            # Forward + Backward pass and optimization
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate statistics for epoch summary
            predicted = outputs.detach().argmax(dim=1)
            epoch_loss += loss.item()
            epoch_correct += (predicted == labels).sum().item()
            epoch_total += labels.size(0)
            train_batches += 1
        train_time = time.time() - train_start_time

        # Average loss is per batch, accuracy is per image; both fall back
        # to 0 for an empty loader instead of dividing by zero.
        avg_epoch_loss = epoch_loss / train_batches if train_batches else 0.0
        avg_epoch_accuracy = (
            100 * epoch_correct / epoch_total if epoch_total else 0.0
        )
        print(
            f"Epoch {epoch + 1}: Train Loss: {avg_epoch_loss:.4f}, "
            f"Train Accuracy: {avg_epoch_accuracy:.2f}%"
        )

        epoch_stats = {
            "train_loss": avg_epoch_loss,
            "train_accuracy": avg_epoch_accuracy,
        }
        test_stats = {}

        # Testing phase
        if test_loader is not None:
            test_start_time = time.time()
            model.eval()  # Set model to evaluation mode
            test_loss = 0.0
            test_correct = 0
            test_total = 0
            test_batches = 0
            with torch.no_grad():  # No need to track gradients
                for images, labels in test_loader:
                    outputs = model(images)
                    test_loss += criterion(outputs, labels).item()

                    # Calculate Testing accuracy as well
                    predicted = outputs.argmax(dim=1)
                    test_total += labels.size(0)
                    test_correct += (predicted == labels).sum().item()
                    test_batches += 1
            test_time = time.time() - test_start_time

            # Calculate average Testing loss and accuracy
            avg_test_loss = test_loss / test_batches if test_batches else 0.0
            test_accuracy = 100 * test_correct / test_total if test_total else 0.0
            print(
                f"\t Testing Loss: {avg_test_loss:.4f}, "
                f"Testing Accuracy: {test_accuracy:.2f}%"
            )

            # Save the model with the best Testing loss
            if avg_test_loss < best_test_loss:
                best_test_loss = avg_test_loss
                torch.save(model.state_dict(), save_path)
                print(
                    f"New best model saved with Testing loss: "
                    f"{avg_test_loss:.4f} and Testing accuracy: "
                    f"{test_accuracy:.2f}%"
                )

            epoch_stats["testing_loss"] = avg_test_loss
            epoch_stats["testing_accuracy"] = test_accuracy
            test_stats = {
                "testing_time": test_time,
                "test_time_per_image": test_time / test_total if test_total else 0.0,
            }

        epoch_stats["training_time"] = train_time
        epoch_stats["train_time_per_image"] = (
            train_time / epoch_total if epoch_total else 0.0
        )
        epoch_stats.update(test_stats)
        history[epoch] = epoch_stats

    if test_loader is None:
        # Without evaluation there is no "best" checkpoint; persist the final
        # weights so the path reported in the summary really exists.
        torch.save(model.state_dict(), save_path)

    print("\nTraining Completed!\n")
    print("\n###########SUMMARY#############\n")
    print(f"Best Testing loss: {best_test_loss:.4f}")
    print(f"Model saved at: {save_path}\n")
    print("################################\n")

    return model, history
def fine_tune_executorch_model(
    model_path: str,
    save_path: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    learning_rate: float = 0.001,
    momentum: float = 0.9,
) -> tuple[ExecuTorchModule, typing.Dict[int, typing.Dict[str, float]]]:
    """
    Fine-tune an ExecuTorch model using a training and validation dataset.

    The program at ``model_path`` is loaded from its bytes and run batch by
    batch. Its ``forward`` outputs are read as ``[loss, predictions,
    gradients..., parameters...]``; the positions where the gradients and
    the parameters start are queried from the program itself. Parameters are
    updated in place with a hand-written SGD-with-momentum step, and after
    each epoch the model is evaluated on at most 100 validation batches.

    Args:
        model_path (str): Path to the ExecuTorch model file to be
            fine-tuned.
        save_path (str): Intended destination for the fine-tuned model.
            Currently unused: this function does not write the model back
            to disk.
        train_loader (DataLoader): DataLoader for the training dataset.
        val_loader (DataLoader): DataLoader for the validation dataset.
        epochs (int, optional): Number of epochs for fine-tuning.
        learning_rate (float, optional): Learning rate for parameter
            updates (default: 0.001).
        momentum (float, optional): Momentum for parameter updates
            (default: 0.9). A value of 0 selects plain SGD.

    Returns:
        tuple: A tuple containing the fine-tuned ExecuTorchModule
            and a dictionary, keyed by epoch index, with training and
            validation metrics and timings.
    """
    with open(model_path, "rb") as f:
        model_bytes = f.read()
        et_mod = _load_for_executorch_from_buffer(model_bytes)

    # Offsets into the forward() output list where gradients / parameters begin.
    grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0]
    param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0]
    history = {}

    # Momentum buffers for SGD with momentum, keyed by parameter position.
    momentum_buffers = {}

    # Upper bound on the number of validation *batches* evaluated per epoch.
    max_val_batches = 100

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        epoch_loss = 0.0
        train_correct = 0
        train_total = 0
        train_batches = 0
        train_start_time = time.time()

        for inputs, labels in tqdm(train_loader):
            # Forward pass
            # NOTE(review): with clone_outputs=False the returned tensors
            # presumably alias the module's own buffers, which is what lets the
            # in-place updates below change the model - confirm in pybindings.
            out = et_mod.forward((inputs, labels), clone_outputs=False)
            loss = out[0]
            predicted = out[1]
            epoch_loss += loss.item()
            train_batches += 1

            # Calculate accuracy
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)

            # Update parameters using SGD with momentum
            with torch.no_grad():
                for param_idx, (grad, param) in enumerate(
                    zip(out[grad_start:param_start], out[param_start:])
                ):
                    if momentum > 0:
                        # Initialize momentum buffer if not exists
                        if param_idx not in momentum_buffers:
                            momentum_buffers[param_idx] = torch.zeros_like(grad)

                        # Update momentum buffer: v = momentum * v + grad
                        momentum_buffers[param_idx].mul_(momentum).add_(grad)
                        # Update parameter: param = param - lr * v
                        param.sub_(learning_rate * momentum_buffers[param_idx])
                    else:
                        # Standard SGD without momentum
                        param.sub_(learning_rate * grad)

        train_time = time.time() - train_start_time
        train_accuracy = 100 * train_correct / train_total if train_total != 0 else 0
        avg_epoch_loss = epoch_loss / train_batches if train_batches else 0

        # Evaluate on validation set
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_batches_processed = 0
        val_start_time = time.time()

        # Wrapping the loader (not the enumerate object) lets tqdm see its length.
        for i, (inputs, labels) in enumerate(tqdm(val_loader)):
            if i >= max_val_batches:
                print(f"Reached {max_val_batches} batches for validation")
                break

            val_batches_processed += 1

            # Forward pass with full batch
            out = et_mod.forward((inputs, labels), clone_outputs=False)
            loss = out[0]
            predicted = out[1]
            val_loss += loss.item()

            # Calculate accuracy
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

        val_time = time.time() - val_start_time
        val_accuracy = 100 * val_correct / val_total if val_total != 0 else 0
        avg_val_loss = (
            val_loss / val_batches_processed if val_batches_processed > 0 else 0
        )

        history[epoch] = {
            "train_loss": avg_epoch_loss,
            "train_accuracy": train_accuracy,
            "validation_loss": avg_val_loss,
            "validation_accuracy": val_accuracy,
            "training_time": train_time,
            # Same zero guard as the accuracies: an empty loader must not
            # abort the run with a ZeroDivisionError.
            "train_time_per_image": train_time / train_total if train_total else 0,
            "testing_time": val_time,
            "test_time_per_image": val_time / val_total if val_total else 0,
        }

    return et_mod, history
def train_both_models(
    pytorch_model: torch.nn.Module,
    et_model_path: str,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    momentum: float = 0.9,
    pytorch_save_path: str = "./best_cifar10_model.pth",
    et_save_path: str = "./best_cifar10_et_model.pte",
) -> typing.Tuple[
    torch.nn.Module,
    typing.Any,
    typing.Dict[int, typing.Dict[str, float]],
    typing.Dict[int, typing.Dict[str, float]],
]:
    """
    Train a PyTorch model and an ExecuTorch model side by side on the same data.

    Every batch is fed first to the PyTorch model and then to the ExecuTorch
    training module, which makes debugging and comparison easier. Metrics and
    wall-clock timings are tracked separately for both models and printed
    after each epoch.

    Args:
        pytorch_model (torch.nn.Module): The PyTorch model to be trained
        et_model_path (str): Path to the ExecuTorch model file
        train_loader (DataLoader): DataLoader for the training dataset
        test_loader (DataLoader): DataLoader for the testing/validation dataset
        epochs (int, optional): Number of epochs for training. Defaults to 10.
        lr (float, optional): Learning rate for parameter updates. Defaults to 0.001.
        momentum (float, optional): Momentum for parameter updates. Defaults to 0.9.
        pytorch_save_path (str, optional): Path to save the best PyTorch model.
            Defaults to "./best_cifar10_model.pth".
        et_save_path (str, optional): Path reserved for the best ExecuTorch
            model. Only its parent directory is created; this function does
            not write an ExecuTorch file. Defaults to
            "./best_cifar10_et_model.pte".

    Returns:
        tuple: A tuple containing:
            - The trained PyTorch model
            - The trained ExecuTorch model
            - Dictionary with PyTorch training and validation metrics
            - Dictionary with ExecuTorch training and validation metrics
    """

    def _div(numerator: float, denominator: float, default: float = 0.0) -> float:
        # Division that yields `default` instead of raising on a zero denominator.
        return numerator / denominator if denominator else default

    # Load the ExecuTorch model
    with open(et_model_path, "rb") as f:
        model_bytes = f.read()
        et_mod = _load_for_executorch_for_training_from_buffer(model_bytes)

    # Initialize histories for both models
    pytorch_history = {}
    et_history = {}

    # Initialize criterion and optimizer for PyTorch model
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(pytorch_model.parameters(), lr=lr, momentum=momentum)

    # TODO: Fix "RuntimeError: Must call forward_backward before named_params.
    # This will be fixed in a later version"
    # Run one batch through the module so named_parameters() becomes available;
    # the outputs of this warm-up pass are not needed.
    images, labels = next(iter(train_loader))
    et_mod.forward_backward(method_name="forward", inputs=(images, labels))

    et_model_optimizer = get_sgd_optimizer(
        et_mod.named_parameters(),
        lr,
        momentum,
    )

    # Initialize best testing loss for checkpointing
    best_pytorch_test_loss = float("inf")
    best_et_test_loss = float("inf")

    # Create directories for both save paths once, up front
    for path in (pytorch_save_path, et_save_path):
        save_dir = os.path.dirname(path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        pytorch_model.train()

        # Initialize metrics for this epoch
        pytorch_epoch_loss = 0.0
        pytorch_correct = 0
        pytorch_total = 0

        et_epoch_loss = 0.0
        et_correct = 0
        et_total = 0

        train_batches = 0
        pytorch_train_time = 0.0
        et_train_time = 0.0

        # Training loop
        for inputs, labels in tqdm(train_loader, desc="Training"):
            batch_size = labels.size(0)
            train_batches += 1

            # ---- PyTorch model training ----
            pytorch_start_time = time.time()

            optimizer.zero_grad()
            pytorch_outputs = pytorch_model(inputs)
            pytorch_loss = criterion(pytorch_outputs, labels)
            pytorch_loss.backward()
            optimizer.step()

            pytorch_train_time += time.time() - pytorch_start_time

            # Accuracy and loss bookkeeping happens outside the timed section
            pytorch_predicted = pytorch_outputs.detach().argmax(dim=1)
            pytorch_correct += (pytorch_predicted == labels).sum().item()
            pytorch_total += batch_size
            pytorch_epoch_loss += pytorch_loss.item()

            # ---- ExecuTorch model training ----
            et_start_time = time.time()

            # Forward and backward pass; outputs start with [loss, predictions]
            et_out = et_mod.forward_backward(
                method_name="forward", inputs=(inputs, labels)
            )
            et_loss = et_out[0]
            et_predicted = et_out[1]

            # Apply the gradients of this pass through the SGD optimizer
            et_model_optimizer.step(et_mod.named_gradients())

            et_train_time += time.time() - et_start_time

            et_correct += (et_predicted == labels).sum().item()
            et_total += batch_size
            et_epoch_loss += et_loss.item()

        # Calculate training metrics
        avg_pytorch_train_loss = _div(pytorch_epoch_loss, train_batches)
        pytorch_train_accuracy = _div(100 * pytorch_correct, pytorch_total)

        avg_et_train_loss = _div(et_epoch_loss, train_batches)
        et_train_accuracy = _div(100 * et_correct, et_total)

        print(
            f"PyTorch - Train Loss: {avg_pytorch_train_loss:.4f}, Train Accuracy: {pytorch_train_accuracy:.2f}%"
        )
        print(
            f"ExecutorTorch - Train Loss: {avg_et_train_loss:.4f}, Train Accuracy: {et_train_accuracy:.2f}%"
        )

        # Testing/Validation phase
        pytorch_model.eval()

        pytorch_test_loss = 0.0
        pytorch_test_correct = 0
        pytorch_test_total = 0
        pytorch_test_time = 0.0

        et_test_loss = 0.0
        et_test_correct = 0
        et_test_total = 0
        et_test_time = 0.0

        test_batches = 0

        with torch.no_grad():
            for inputs, labels in tqdm(test_loader, desc="Testing"):
                batch_size = labels.size(0)
                test_batches += 1

                # ---- PyTorch model testing ----
                pytorch_test_start = time.time()

                pytorch_outputs = pytorch_model(inputs)
                pytorch_loss = criterion(pytorch_outputs, labels)

                pytorch_test_time += time.time() - pytorch_test_start

                pytorch_test_loss += pytorch_loss.item()
                pytorch_predicted = pytorch_outputs.argmax(dim=1)
                pytorch_test_correct += (pytorch_predicted == labels).sum().item()
                pytorch_test_total += batch_size

                # ---- ExecuTorch model testing ----
                et_test_start = time.time()

                # forward_backward is reused for evaluation; no optimizer step
                # follows, so the weights are left untouched here.
                # NOTE(review): the ET timing therefore presumably includes a
                # backward pass the PyTorch side does not run - confirm.
                et_out = et_mod.forward_backward(
                    method_name="forward", inputs=(inputs, labels)
                )
                et_loss = et_out[0]
                et_predicted = et_out[1]

                et_test_time += time.time() - et_test_start

                et_test_loss += et_loss.item()
                et_test_correct += (et_predicted == labels).sum().item()
                et_test_total += batch_size

        # Calculate testing metrics
        avg_pytorch_test_loss = _div(pytorch_test_loss, test_batches)
        pytorch_test_accuracy = _div(100 * pytorch_test_correct, pytorch_test_total)

        avg_et_test_loss = _div(et_test_loss, test_batches)
        et_test_accuracy = _div(100 * et_test_correct, et_test_total)

        print(
            f"PyTorch - Test Loss: {avg_pytorch_test_loss:.4f}, Test Accuracy: {pytorch_test_accuracy:.2f}%"
        )
        print(
            f"ExecutorTorch - Test Loss: {avg_et_test_loss:.4f}, Test Accuracy: {et_test_accuracy:.2f}%"
        )

        # Compare losses
        loss_diff = abs(avg_pytorch_test_loss - avg_et_test_loss)
        print(f"Loss Difference: {loss_diff:.6f}")

        # Save the best PyTorch model
        if avg_pytorch_test_loss < best_pytorch_test_loss:
            best_pytorch_test_loss = avg_pytorch_test_loss
            torch.save(pytorch_model.state_dict(), pytorch_save_path)
            print(
                f"New best PyTorch model saved with test loss: {avg_pytorch_test_loss:.4f}"
            )

        # Track the best ExecuTorch loss; no file is written for it here
        if avg_et_test_loss < best_et_test_loss:
            best_et_test_loss = avg_et_test_loss
            print(
                f"New best ExecutorTorch model with test loss: {avg_et_test_loss:.4f}"
            )

        # Store history for both models (metrics first, then timings)
        pytorch_history[epoch] = {
            "train_loss": avg_pytorch_train_loss,
            "train_accuracy": pytorch_train_accuracy,
            "test_loss": avg_pytorch_test_loss,
            "test_accuracy": pytorch_test_accuracy,
            "training_time": pytorch_train_time,
            "train_time_per_image": _div(pytorch_train_time, pytorch_total),
            "testing_time": pytorch_test_time,
            "test_time_per_image": _div(pytorch_test_time, pytorch_test_total),
        }

        et_history[epoch] = {
            "train_loss": avg_et_train_loss,
            "train_accuracy": et_train_accuracy,
            "test_loss": avg_et_test_loss,
            "test_accuracy": et_test_accuracy,
            "training_time": et_train_time,
            "train_time_per_image": _div(et_train_time, et_total),
            "testing_time": et_test_time,
            "test_time_per_image": _div(et_test_time, et_test_total),
        }

        # Print timing comparison; a zero PyTorch time gives "nan", not a crash
        train_time_ratio = _div(et_train_time, pytorch_train_time, float("nan"))
        test_time_ratio = _div(et_test_time, pytorch_test_time, float("nan"))
        print(
            f"PyTorch training time: {pytorch_train_time:.4f}s, testing time: {pytorch_test_time:.4f}s"
        )
        print(
            f"ExecutorTorch training time: {et_train_time:.4f}s, testing time: {et_test_time:.4f}s"
        )
        print(f"Training time ratio (ET/PT): {train_time_ratio:.4f}")
        print(f"Testing time ratio (ET/PT): {test_time_ratio:.4f}")

    print("\nTraining Completed!\n")
    print("\n###########SUMMARY#############\n")
    print(f"Best PyTorch test loss: {best_pytorch_test_loss:.4f}")
    print(f"Best ExecutorTorch test loss: {best_et_test_loss:.4f}")
    print(
        f"Final loss difference: {abs(best_pytorch_test_loss - best_et_test_loss):.6f}"
    )
    print(f"PyTorch model saved at: {pytorch_save_path}")
    print(f"ExecutorTorch model path: {et_save_path}")
    print("################################\n")

    return pytorch_model, et_mod, pytorch_history, et_history
- """ - - def __init__( - self, - data_path: str, - transform: typing.Optional[torchvision.transforms.Compose] = None, - ) -> None: - """ - Args: - data_path: Path to the balanced dataset binary file - transform: Optional transformation to be applied on a sample - """ - self.data = [] - self.labels = [] - - # Read binary format: 1 byte label + 3072 bytes image data per record - with open(data_path, "rb") as f: - while True: - # Read label (1 byte) - label_byte = f.read(1) - if not label_byte: # End of file - break - label = int.from_bytes(label_byte, byteorder="big") - - # Read image data (3 * 32 * 32 = 3072 bytes) - image_bytes = f.read(3072) - if len(image_bytes) != 3072: - break # Incomplete record - - # Convert bytes to numpy array - image_data = np.frombuffer(image_bytes, dtype=np.uint8) - - self.data.append(image_data) - self.labels.append(label) - - self.data = np.array(self.data) - self.labels = np.array(self.labels) - self.transform = transform - - print(f"Loaded {len(self.data)} images from {data_path}") - - def __len__(self) -> int: - return len(self.data) - - def __getitem__(self, idx: int) -> typing.Tuple[Image.Image, int]: - # Reshape from (3072,) to (32, 32, 3) and convert to PIL Image - image_data = self.data[idx].reshape(3, 32, 32).transpose(1, 2, 0) - image = Image.fromarray(image_data) - label = self.labels[idx] - - if self.transform: - image = self.transform(image) - - return image, label - - -def create_balanced_cifar_dataset( - data_batch_path: str = "./data/cifar-10/cifar-10-batches-py/data_batch_1", - output_path: str = "./data/cifar-10/extracted_data/train_data.bin", - images_per_class: int = 100, -) -> str: - """ - Reads CIFAR-10 data from data_batch_1 file and creates a balanced dataset - with specified number of images per class, saved in binary format - compatible with Android. 
- - Args: - data_batch_path: Path to the CIFAR-10 data_batch_1 file - output_path: Path where the balanced dataset will be saved - images_per_class: Number of images to extract per class (default: 100) - """ - # Load the CIFAR-10 data batch - with open(data_batch_path, "rb") as f: - data_dict = pickle.load(f, encoding="bytes") - - # Extract data and labels - data = data_dict[b"data"] # Shape: (10000, 3072) - labels = data_dict[b"labels"] # List of 10000 labels - - # Group images by class - class_images = defaultdict(list) - class_labels = defaultdict(list) - - for i, label in enumerate(labels): - if len(class_images[label]) < images_per_class: - class_images[label].append(data[i]) - class_labels[label].append(label) - - # Combine all selected images and labels - selected_data = [] - selected_labels = [] - - for class_id in range(10): # CIFAR-10 has 10 classes (0-9) - if class_id in class_images: - selected_data.extend(class_images[class_id]) - selected_labels.extend(class_labels[class_id]) - print( - f"Class {class_id}: " f"{len(class_images[class_id])} images selected" - ) - - # Convert to numpy arrays - selected_data = np.array(selected_data, dtype=np.uint8) - selected_labels = np.array(selected_labels, dtype=np.uint8) - - # Ensure the output directory exists - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - # Save in binary format compatible with Android CIFAR-10 reader - # Format: 1 byte label + 3072 bytes image data per record - with open(output_path, "wb") as f: - for i in range(len(selected_data)): - # Write label as single byte - f.write(bytes([selected_labels[i]])) - # Write image data (3072 bytes) - f.write(selected_data[i].tobytes()) - - print(f"Balanced dataset saved to {output_path}") - print(f"Total images: {len(selected_data)}") - print(f"File size: {os.path.getsize(output_path)} bytes") - print(f"Expected size: {len(selected_data) * (1 + 3072)} bytes") - return output_path - - -def get_data_loaders( - batch_size: int = 4, - num_workers: 
int = 2, - data_dir: str = "./data", - use_balanced_dataset: bool = True, - images_per_class: int = 100, -) -> typing.Tuple[DataLoader, DataLoader]: - """ - Create data loaders for training, validation, and testing. - - Args: - batch_size: Batch size for data loaders - num_workers: Number of worker processes for data loading - data_dir: Root directory for data - use_balanced_dataset: Whether to use balanced dataset or - standard CIFAR-10 - images_per_class: Number of images per class for balanced dataset - """ - transforms = torchvision.transforms.Compose( - [ - torchvision.transforms.RandomCrop(32, padding=4), - torchvision.transforms.RandomHorizontalFlip(), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) - ), - ] - ) - - if use_balanced_dataset: - # Download CIFAR-10 first to ensure the raw data exists - print("Downloading CIFAR-10 dataset...") - torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True) - torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True) - - # The actual path where torchvision stores CIFAR-10 data - cifar_data_dir = os.path.join(data_dir, "cifar-10-batches-py") - - # Create balanced dataset if it doesn't exist - balanced_data_path = os.path.join( - data_dir, "cifar-10/extracted_data/train_data.bin" - ) - data_batch_path = os.path.join(cifar_data_dir, "data_batch_1") - - # Ensure the output directory exists - os.makedirs(os.path.dirname(balanced_data_path), exist_ok=True) - - # Create balanced dataset if it doesn't exist - if not os.path.exists(balanced_data_path): - print("Creating balanced train dataset...") - create_balanced_cifar_dataset( - data_batch_path=data_batch_path, - output_path=balanced_data_path, - images_per_class=images_per_class, - ) - - # Use balanced dataset for training - trainset = BalancedCIFARDataset(balanced_data_path, transform=transforms) - - indices = torch.randperm(len(trainset)).tolist() - - train_subset = 
Subset(trainset, indices) - - balanced_test_data_path = os.path.join( - data_dir, "cifar-10/extracted_data/test_data.bin" - ) - test_data_batch_path = os.path.join(cifar_data_dir, "test_batch") - # Ensure the output directory exists - os.makedirs(os.path.dirname(balanced_test_data_path), exist_ok=True) - # Create balanced dataset if it doesn't exist - if not os.path.exists(balanced_test_data_path): - print("Creating balanced test dataset...") - create_balanced_cifar_dataset( - data_batch_path=test_data_batch_path, - output_path=balanced_test_data_path, - images_per_class=images_per_class, - ) - # Use balanced dataset for testing - test_set = BalancedCIFARDataset(balanced_test_data_path, transform=transforms) - - else: - # Use standard CIFAR-10 dataset - trainset = torchvision.datasets.CIFAR10( - root=data_dir, train=True, download=True, transform=transforms - ) - - train_set_indices = torch.randperm(len(trainset)).tolist() - - train_subset = Subset(trainset, train_set_indices) - - # Test set always uses standard CIFAR-10 - test_set = torchvision.datasets.CIFAR10( - root=data_dir, train=False, download=True, transform=transforms - ) - - train_loader = DataLoader( - train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers - ) - - test_loader = DataLoader( - test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers - ) - - return train_loader, test_loader - - -def count_images_per_class(loader: DataLoader) -> typing.Dict[int, int]: - """ - Count the number of images per class in a DataLoader. - - This function iterates through a DataLoader and counts how many images - belong to each class based on their labels. 
- - Args: - loader (DataLoader): The DataLoader containing image-label pairs - - Returns: - Dict[int, int]: A dictionary mapping class IDs to their counts - """ - class_counts = defaultdict(int) - for _, labels in loader: - for label in labels: - class_counts[label.item()] += 1 - return class_counts - - -def save_json( - history: typing.Dict[int, typing.Dict[str, float]], json_path: str -) -> str: - """ - Save training/validation history to a JSON file. - - This function takes a dictionary containing training/validation metrics - organized by epoch and saves it to a JSON file at the specified path. - - Args: - history (Dict[int, Dict[str, float]]): Dictionary with epoch numbers - as keys and dictionaries of metrics (loss, accuracy, etc.) as - values. - json_path (str): File path where the JSON file will be saved. - - Returns: - str: The path where the JSON file was saved. - """ - with open(json_path, "w") as f: - json.dump(history, f, indent=4) - print(f"History saved to {json_path}") - return json_path - - -def train_model( - model: torch.nn.Module, - train_loader: DataLoader, - test_loader: DataLoader, - epochs: int = 1, - lr: float = 0.001, - momentum: float = 0.9, - save_path: str = "./best_cifar10_model.pth", -) -> typing.Tuple[torch.nn.Module, typing.Dict[int, typing.Dict[str, float]]]: - """ - The train_model function takes a model, a train_loader, and the number of - epochs as input.It then trains the model on the training data for the - specified number of epochs using the SGD optimizer and a cross-entropy loss - function. The function returns the trained model. - - args: - model (Required): The model to be trained. - train_loader (tuple, Required): The training data loader. - test_loader (tuple, Optional): The testing data loader. - epochs (int, optional): The number of epochs to train the model. - lr (float, optional): The learning rate for the SGD optimizer. - momentum (float, optional): The momentum for the SGD optimizer. 
- save_path (str, optional): Path to save the best model. - """ - - history = {} - criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum) - - # Initialize best testing loss to a high value for checkpointing - # on the best model - best_test_loss = float("inf") - - # Create directory for save_path if it doesn't exist - save_dir = os.path.dirname(save_path) - if save_dir and not os.path.exists(save_dir): - os.makedirs(save_dir) - - train_start_time = time.time() - # Training loop - for epoch in range(epochs): - model.train() - epoch_loss = 0.0 - epoch_correct = 0 - epoch_total = 0 - for data in train_loader: - # Get the input data as a list of [inputs, labels] - inputs, labels = data - - # Set the gradients to zero for the next backward pass - optimizer.zero_grad() - - # Forward + Backward pass and optimization - outputs = model(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - # Calculate correct predictions for epoch statistics - _, predicted = torch.max(outputs.data, 1) - total = labels.size(0) - correct = (predicted == labels).sum().item() - - # Accumulate statistics for epoch summary - epoch_loss += loss.detach().item() - epoch_correct += correct - epoch_total += total - - train_end_time = time.time() - # Calculate the stats for average loss and accuracy for - # the entire epoch - avg_epoch_loss = epoch_loss / len(train_loader) - avg_epoch_accuracy = 100 * epoch_correct / epoch_total - print( - f"Epoch {epoch + 1}: Train Loss: {avg_epoch_loss:.4f}, " - f"Train Accuracy: {avg_epoch_accuracy:.2f}%" - ) - - test_start_time = time.time() - # Testing phase - if test_loader is not None: - model.eval() # Set model to evaluation mode - test_loss = 0.0 - test_correct = 0 - test_total = 0 - with torch.no_grad(): # No need to track gradients - for data in test_loader: - images, labels = data - outputs = model(images) - loss = criterion(outputs, labels) - test_loss += 
loss.detach().item() - - # Calculate Testing accuracy as well - _, predicted = torch.max(outputs.data, 1) - test_total += labels.size(0) - test_correct += (predicted == labels).sum().item() - - # Calculate average Testing loss and accuracy - avg_test_loss = test_loss / len(test_loader) - test_accuracy = 100 * test_correct / test_total - test_end_time = time.time() - print( - f"\t Testing Loss: {avg_test_loss:.4f}, " - f"Testing Accuracy: {test_accuracy:.2f}%" - ) - - # Save the model with the best Testing loss - if avg_test_loss < best_test_loss: - best_test_loss = avg_test_loss - torch.save(model.state_dict(), save_path) - print( - f"New best model saved with Testing loss: " - f"{avg_test_loss:.4f} and Testing accuracy: " - f"{test_accuracy:.2f}%" - ) - - history[epoch] = { - "train_loss": avg_epoch_loss, - "train_accuracy": avg_epoch_accuracy, - "testing_loss": avg_test_loss, - "testing_accuracy": test_accuracy, - "training_time": train_end_time - train_start_time, - "train_time_per_image": (train_end_time - train_start_time) - / epoch_total, - "testing_time": test_end_time - test_start_time, - "test_time_per_image": (test_end_time - test_start_time) / test_total, - } - - print("\nTraining Completed!\n") - print("\n###########SUMMARY#############\n") - print(f"Best Testing loss: {best_test_loss:.4f}") - print(f"Model saved at: {save_path}\n") - print("################################\n") - - return model, history - - -def fine_tune_executorch_model( - model_path: str, - save_path: str, - train_loader: DataLoader, - val_loader: DataLoader, - epochs: int = 10, - learning_rate: float = 0.001, - momentum: float = 0.9, -) -> tuple[ExecuTorchModule, typing.Dict[str, typing.Any]]: - """ - Fine-tune an ExecutorTorch model using a training and validation dataset. - - This function loads an ExecutorTorch model from a file, fine-tunes it using - the provided training data loader, and evaluates it on the validation data - loader. 
The function returns the fine-tuned model and a history dictionary - containing training and validation metrics. - - Args: - model_path (str): Path to the ExecutorTorch model file to be - fine-tuned. - save_path (str): Path where the fine-tuned model will be saved. - train_loader (DataLoader): DataLoader for the training dataset. - val_loader (DataLoader): DataLoader for the validation dataset. - epochs (int, optional): Number of epochs for fine-tuning. - learning_rate (float, optional): Learning rate for parameter - updates (default: 0.001). - momentum (float, optional): Momentum for parameter updates - (default: 0.9). - - Returns: - tuple: A tuple containing the fine-tuned ExecutorTorchModule - and a dictionary with training and validation metrics. - """ - with open(model_path, "rb") as f: - model_bytes = f.read() - et_mod = _load_for_executorch_from_buffer(model_bytes) - - grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0] - param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0] - history = {} - - # Initialize momentum buffers for SGD with momentum - momentum_buffers = {} - - for epoch in range(epochs): - print(f"Epoch {epoch+1}/{epochs}") - epoch_loss = 0.0 - train_correct = 0 - train_total = 0 - train_start_time = time.time() - - for batch in tqdm(train_loader): - inputs, labels = batch - # Process each image-label pair individually - for i in range(len(inputs)): - input_image = inputs[ - i : i + 1 - ] # Use list slicing to extract single image - label = labels[i : i + 1] - # Forward pass - out = et_mod.forward((input_image, label), clone_outputs=False) - loss = out[0] - predicted = out[1] - epoch_loss += loss.item() - - # Calculate accuracy - if predicted.item() == label.item(): - train_correct += 1 - train_total += 1 - - # Update parameters using SGD with momentum - with torch.no_grad(): - for param_idx, (grad, param) in enumerate( - zip(out[grad_start:param_start], out[param_start:]) - ): - if momentum > 
0: - # Initialize momentum buffer if not exists - if param_idx not in momentum_buffers: - momentum_buffers[param_idx] = torch.zeros_like(grad) - - # Update momentum buffer: v = momentum * v + grad - momentum_buffers[param_idx].mul_(momentum).add_(grad) - # Update parameter: param = param - lr * v - param.sub_(learning_rate * momentum_buffers[param_idx]) - else: - # Standard SGD without momentum - param.sub_(learning_rate * grad) - - train_end_time = time.time() - train_accuracy = 100 * train_correct / train_total if train_total != 0 else 0 - - avg_epoch_loss = epoch_loss / len(train_loader) / (train_loader.batch_size or 1) - - # Evaluate on validation set - - val_loss = 0.0 - val_correct = 0 - val_total = 0 - val_samples = 100 # Limiting validation samples to 100 - val_start_time = time.time() - - for i, val_batch in tqdm(enumerate(val_loader)): - if i == val_samples: - print(f"Reached {val_samples} samples for validation") - break - - inputs, labels = val_batch - - for i in range(len(inputs)): - input_image = inputs[ - i : i + 1 - ] # Use list slicing to extract single image - label = labels[i : i + 1] - # Forward pass - out = et_mod.forward((input_image, label), clone_outputs=False) - loss = out[0] - predicted = out[1] - val_loss += loss.item() - # Calculate accuracy - if predicted.item() == label.item(): - val_correct += 1 - val_total += 1 - - val_end_time = time.time() - val_accuracy = 100 * val_correct / val_total if val_total != 0 else 0 - avg_val_loss = val_loss / len(val_loader) - avg_val_loss /= val_loader.batch_size or 1 - - history[epoch] = { - "train_loss": avg_epoch_loss, - "train_accuracy": train_accuracy, - "validation_loss": avg_val_loss, - "validation_accuracy": val_accuracy, - "training_time": train_end_time - train_start_time, - "train_time_per_image": (train_end_time - train_start_time) / train_total, - "testing_time": val_end_time - val_start_time, - "test_time_per_image": (val_end_time - val_start_time) / val_total, - } - - return et_mod, 
history