diff --git a/python/infinicore/nn/__init__.py b/python/infinicore/nn/__init__.py
new file mode 100644
index 000000000..270dcebaa
--- /dev/null
+++ b/python/infinicore/nn/__init__.py
@@ -0,0 +1 @@
+from .modules import *
diff --git a/python/infinicore/nn/modules/__init__.py b/python/infinicore/nn/modules/__init__.py
new file mode 100644
index 000000000..e7f51953b
--- /dev/null
+++ b/python/infinicore/nn/modules/__init__.py
@@ -0,0 +1 @@
+from .module import InfiniCoreModule as Module
diff --git a/python/infinicore/nn/modules/module.py b/python/infinicore/nn/modules/module.py
new file mode 100644
index 000000000..ccefbe635
--- /dev/null
+++ b/python/infinicore/nn/modules/module.py
@@ -0,0 +1,717 @@
+# Copyright (c) 2025, InfiniCore
+#
+# This file contains modified code derived from PyTorch's `torch.nn.Module`
+# implementation, which is licensed under the BSD 3-Clause License.
+#
+# The modifications include adaptations for the InfiniCore framework, custom
+# parameter/buffer registration mechanisms, and simplified state_dict handling.
+#
+# Original PyTorch source:
+# https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
+#
+# Referencing PyTorch v2.4.0
+#
+# The use of this file is governed by the BSD 3-Clause License.
+
+from collections import OrderedDict, namedtuple
+import itertools
+import warnings
+
+import torch
+
+from typing import Union, Tuple, Any, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
+T = TypeVar('T', bound='InfiniCoreModule')
+
+class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
+    def __repr__(self):
+        if not self.missing_keys and not self.unexpected_keys:
+            return ''
+        return super().__repr__()
+
+    __str__ = __repr__
+
+
+class InfiniCoreModule:
+    r"""Base class for InfiniCore neural network modules.
+    Your models should also subclass this class.
+
+    Modules can contain other Modules, allowing them to be nested in a
+    tree structure.
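+
+    You can assign the submodules as regular attributes. A minimal sketch of
+    a subclass (assuming the standard ``torch.nn`` layers)::
+
+        import torch.nn as nn
+        import torch.nn.functional as F
+
+        class Model(InfiniCoreModule):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(1, 20, 5)
+                self.conv2 = nn.Conv2d(20, 20, 5)
+
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                return F.relu(self.conv2(x))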
+ """ + + _version: int = 1 + + training: bool + _parameters: Dict[str, Optional[torch.nn.Parameter]] + _buffers: Dict[str, Optional[torch.Tensor]] + _non_persistent_buffers_set: Set[str] + _modules: Dict[str, Optional['InfiniCoreModule']] + + def __init__(self): + super().__setattr__("_parameters", OrderedDict()) + super().__setattr__("_buffers", OrderedDict()) + super().__setattr__("_non_persistent_buffers_set", set()) + super().__setattr__("_modules", OrderedDict()) + + def __getattr__(self, name: str) -> Any: + if "_parameters" in self.__dict__: + _parameters = self.__dict__["_parameters"] + if name in _parameters: + return _parameters[name] + if "_buffers" in self.__dict__: + _buffers = self.__dict__["_buffers"] + if name in _buffers: + return _buffers[name] + if "_modules" in self.__dict__: + modules = self.__dict__["_modules"] + if name in modules: + return modules[name] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value: Union[torch.Tensor, 'InfiniCoreModule']) -> None: + def remove_from(*dicts_or_sets) -> None: + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + params = self.__dict__.get("_parameters") + if isinstance(value, torch.nn.Parameter): + if params is None: + raise AttributeError( + "cannot assign parameters before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + ) + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + "(torch.nn.Parameter or None expected)" + ) + self.register_parameter(name, value) + else: + modules = self.__dict__.get("_modules") + if isinstance(value, (torch.nn.Module)): + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + ) + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + "(torch.nn.Module or None expected)" + ) + modules[name] = value + else: + buffers = self.__dict__.get("_buffers") + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + "(torch.Tensor or None expected)" + ) + buffers[name] = value + else: + super().__setattr__(name, value) + + def register_buffer(self, name: str, tensor: Optional[torch.tensor], persistent: bool = True) -> None: + r"""Adds a buffer to the module. + + This is typically used to register a buffer that should not to be + considered a model parameter.Buffers, by default, are persistent + and will be saved alongside parameters. This behavior can be changed + by setting :attr:`persistent` to ``False``. The only difference between + a persistent buffer and a non-persistent buffer is that the latter + will not be a part of this module's :attr:`state_dict`. + + Buffers can be accessed as attributes using given names. + + Args: + name (str): name of the buffer. The buffer can be accessed + from this module using the given name + tensor (Tensor or None): buffer to be registered. 
+        """
+        if '_buffers' not in self.__dict__:
+            raise AttributeError(
+                "cannot assign buffer before Module.__init__() call")
+        elif not isinstance(name, str):
+            raise TypeError(f"buffer name should be a string. "
+                            f"Got {torch.typename(name)}")
+        elif '.' in name:
+            raise KeyError("buffer name can't contain \".\"")
+        elif name == '':
+            raise KeyError("buffer name can't be empty string \"\"")
+        elif hasattr(self, name) and name not in self._buffers:
+            raise KeyError(f"attribute '{name}' already exists")
+        elif tensor is not None and not isinstance(tensor, torch.Tensor):
+            raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
+                            "(torch.Tensor or None required)")
+        else:
+            self._buffers[name] = tensor
+            if persistent:
+                self._non_persistent_buffers_set.discard(name)
+            else:
+                self._non_persistent_buffers_set.add(name)
+
+
+    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+        r"""Add a parameter to the module.
+
+        The parameter can be accessed as an attribute using the given name.
+
+        Args:
+            name (str): name of the parameter. The parameter can be accessed
+                from this module using the given name
+            param (Parameter or None): parameter to be added to the module. If
+                ``None``, then operations that run on parameters, such as :attr:`cuda`,
+                are ignored. If ``None``, the parameter is **not** included in the
+                module's :attr:`state_dict`.
+        """
+        if "_parameters" not in self.__dict__:
+            raise AttributeError(
+                "cannot assign parameter before Module.__init__() call"
+            )
+
+        elif not isinstance(name, str):
+            raise TypeError(
+                f"parameter name should be a string. Got {torch.typename(name)}"
+            )
+        elif "." in name:
+            raise KeyError('parameter name can\'t contain "."')
+        elif name == "":
+            raise KeyError('parameter name can\'t be empty string ""')
+        elif hasattr(self, name) and name not in self._parameters:
+            raise KeyError(f"attribute '{name}' already exists")
+
+        if param is None:
+            self._parameters[name] = None
+        elif not isinstance(param, torch.nn.Parameter):
+            raise TypeError(
+                f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
+                "(torch.nn.Parameter or None required)"
+            )
+        else:
+            self._parameters[name] = param
+
+    def get_extra_state(self) -> Any:
+        """Return any extra state to include in the module's state_dict.
+
+        Implement this and a corresponding :func:`set_extra_state` for your module
+        if you need to store extra state. This function is called when building the
+        module's `state_dict()`.
+
+        Note that extra state should be picklable to ensure working serialization
+        of the state_dict. We only provide backwards compatibility guarantees
+        for serializing Tensors; other objects may break backwards compatibility if
+        their serialized pickled form changes.
+
+        Returns:
+            object: Any extra state to store in the module's state_dict
+        """
+        raise RuntimeError(
+            "Reached a code path in Module.get_extra_state() that should never be called."
+        )
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        r"""Saves module state to the `destination` dictionary, which will
+        contain the state of the module, but not its descendants. This is
+        called on every submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            keep_vars (bool): whether to keep the tensors attached to autograd
+                rather than detaching them
+        """
+        for name, param in self._parameters.items():
+            if param is not None:
+                destination[prefix + name] = param if keep_vars else param.detach()
+        for name, buf in self._buffers.items():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                destination[prefix + name] = buf if keep_vars else buf.detach()
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state", InfiniCoreModule.get_extra_state) is not InfiniCoreModule.get_extra_state:
+            destination[extra_state_key] = self.get_extra_state()
+
+    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case
+    # `state_dict` returns that same object. But if they pass nothing, an `OrderedDict` is
+    # created and returned.
+    T_destination = TypeVar('T_destination', bound=Dict[str, Any])
+
+    @overload
+    def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
+        ...
+
+    @overload
+    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
+        ...
+
+    # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
+    # Also remove the arg parsing logic at that point.
+    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
+        r"""Returns a dictionary containing references to the whole state of the module.
+
+        Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys are the corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        .. note::
+            The returned object is a shallow copy. It contains references
+            to the module's parameters and buffers.
+
+        .. warning::
+            Currently ``state_dict()`` also accepts positional arguments for
+            ``destination``, ``prefix`` and ``keep_vars`` in order. However,
+            this is being deprecated and keyword arguments will be enforced in
+            future releases.
+
+        .. warning::
+            Please avoid the use of argument ``destination`` as it is not
+            designed for end-users.
+
+        Args:
+            destination (dict, optional): If provided, the state of module will
+                be updated into the dict and the same object is returned.
+                Otherwise, an ``OrderedDict`` will be created and returned.
+                Default: ``None``.
+            prefix (str, optional): a prefix added to parameter and buffer
+                names to compose the keys in state_dict. Default: ``''``.
+            keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
+                returned in the state dict are detached from autograd. If it's
+                set to ``True``, detaching will not be performed.
+                Default: ``False``.
+
+        Returns:
+            dict:
+                a dictionary containing the whole state of the module
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined vars")
+            >>> module.state_dict().keys()
+            ['bias', 'weight']
+
+        """
+
+        # TODO: Remove `args` and the parsing logic when BC allows.
+        if len(args) > 0:
+            # DeprecationWarning is ignored by default
+            warnings.warn(
+                "Positional args are being deprecated, use kwargs instead. 
", + FutureWarning, + stacklevel=2, + ) + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == '': + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) + return destination + + def set_extra_state(self, state: Any): + """ + This function is called from :func:`load_state_dict` to handle any extra state + found within the `state_dict`. Implement this function and a corresponding + :func:`get_extra_state` for your module if you need to store extra state within its + `state_dict`. + + Args: + state (dict): Extra state from the `state_dict` + """ + raise RuntimeError( + "Reached a code path in Module.set_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug.") + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. 
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+            if key in state_dict:
+                input_param = state_dict[key]
+                if not torch.overrides.is_tensor_like(input_param):
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                      'received {}'
+                                      .format(key, type(input_param)))
+                    continue
+
+                # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they don't have the hook to do the checks;
+                # in such a case, accessing the .shape attribute would error.
+                is_param_lazy = torch.nn.parameter.is_lazy(param)
+                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                    input_param = input_param[0]
+
+                if not is_param_lazy and input_param.shape != param.shape:
+                    # local shape should match the one in checkpoint
+                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                      'the shape in current model is {}.'
+                                      .format(key, input_param.shape, param.shape))
+                    continue
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'whose dimensions in the model are {} and '
+                                      'whose dimensions in the checkpoint are {}, '
+                                      'an exception occurred: {}.'
+                                      .format(key, param.size(), input_param.size(), ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", InfiniCoreModule.set_extra_state) is not InfiniCoreModule.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):].split(".", 1)
+                    # Must be a Module if it has attributes
+                    if len(input_name) > 1:
+                        if input_name[0] not in self._modules:
+                            unexpected_keys.append(key)
+                    elif input_name[0] not in local_state:
+                        unexpected_keys.append(key)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any],
+                        strict: bool = True):
+        r"""Copies parameters and buffers from :attr:`state_dict` into
+        this module and its descendants. If :attr:`strict` is ``True``, then
+        the keys of :attr:`state_dict` must exactly match the keys returned
+        by this module's :meth:`~torch.nn.Module.state_dict` function.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            strict (bool, optional): whether to strictly enforce that the keys
+                in :attr:`state_dict` match the keys returned by this module's
+                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+        Returns:
+            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+                * **missing_keys** is a list of str containing the missing keys
+                * **unexpected_keys** is a list of str containing the unexpected keys
+
+        Note:
+            If a parameter or buffer is registered as ``None`` and its corresponding key
+            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
+            ``RuntimeError``.
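+
+        Example (a sketch; ``model`` and ``path`` are illustrative names)::
+
+            >>> # xdoctest: +SKIP("undefined vars")
+            >>> state = safetensors.torch.load_file(path)
+            >>> missing, unexpected = model.load_state_dict(state, strict=True)
+            >>> assert not missing and not unexpected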
+        """
+        if not isinstance(state_dict, Mapping):
+            raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
+
+        missing_keys: List[str] = []
+        unexpected_keys: List[str] = []
+        error_msgs: List[str] = []
+
+        # copy state_dict so _load_from_state_dict can modify it
+        metadata = getattr(state_dict, '_metadata', None)
+        state_dict = OrderedDict(state_dict)
+        if metadata is not None:
+            # mypy isn't aware that "_metadata" exists in state_dict
+            state_dict._metadata = metadata  # type: ignore[attr-defined]
+
+        def load(module, local_state_dict, prefix=''):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            module._load_from_state_dict(
+                local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+            for name, child in module._modules.items():
+                if child is not None:
+                    child_prefix = prefix + name + '.'
+                    child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
+                    load(child, child_state_dict, child_prefix)
+
+        load(self, state_dict)
+        del load
+
+        if strict:
+            if len(unexpected_keys) > 0:
+                error_msgs.insert(
+                    0, 'Unexpected key(s) in state_dict: {}. '.format(
+                        ', '.join('"{}"'.format(k) for k in unexpected_keys)))
+            if len(missing_keys) > 0:
+                error_msgs.insert(
+                    0, 'Missing key(s) in state_dict: {}. '.format(
+                        ', '.join('"{}"'.format(k) for k in missing_keys)))
+
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                self.__class__.__name__, "\n\t".join(error_msgs)))
+        return _IncompatibleKeys(missing_keys, unexpected_keys)
+
+    def children(self) -> Iterator['InfiniCoreModule']:
+        r"""Returns an iterator over immediate children modules.
+
+        Yields:
+            Module: a child module
+        """
+        for name, module in self.named_children():
+            yield module
+
+    def named_children(self) -> Iterator[Tuple[str, 'InfiniCoreModule']]:
+        r"""Returns an iterator over immediate children modules, yielding both
+        the name of the module as well as the module itself.
+
+        Yields:
+            (str, Module): Tuple containing a name and child module
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined vars")
+            >>> for name, module in model.named_children():
+            >>>     if name in ['conv4', 'conv5']:
+            >>>         print(module)
+
+        """
+        memo = set()
+        for name, module in self._modules.items():
+            if module is not None and module not in memo:
+                memo.add(module)
+                yield name, module
+
+
+    def train(self: T, mode: bool = True) -> T:
+        r"""Sets the module in training mode.
+
+        This has an effect only on certain modules. See the documentation of
+        particular modules for details of their behavior in training/evaluation
+        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
+        etc.
+
+        Args:
+            mode (bool): whether to set training mode (``True``) or evaluation
+                mode (``False``). Default: ``True``.
+
+        Returns:
+            Module: self
+        """
+        if not isinstance(mode, bool):
+            raise ValueError("training mode is expected to be boolean")
+        self.training = mode
+        for module in self.children():
+            module.train(mode)
+        return self
+
+    def eval(self: T) -> T:
+        r"""Sets the module in evaluation mode.
+
+        This has an effect only on certain modules. See the documentation of
+        particular modules for details of their behavior in training/evaluation
+        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
+        etc.
+
+        This is equivalent to calling ``self.train(False)``.
+
+        See :ref:`locally-disable-grad-doc` for a comparison between
+        `.eval()` and several similar mechanisms that may be confused with it.
+
+        Returns:
+            Module: self
+        """
+        return self.train(False)
+
+
+    def _apply(self, fn, recurse=True):
+        if recurse:
+            for module in self.children():
+                module._apply(fn)
+
+        def compute_should_use_set_data(tensor, tensor_applied):
+            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
+                # If the new tensor has a tensor type compatible with the existing tensor,
+                # the current behavior is to change the tensor in-place using `.data =`,
+                # and the future behavior is to overwrite the existing tensor. However,
+                # changing the current behavior is a BC-breaking change, and we want it
+                # to happen in future releases. So for now we introduce the
+                # `torch.__future__.get_overwrite_module_params_on_conversion()`
+                # global flag to let the user control whether they want the future
+                # behavior of overwriting the existing tensor or not.
+                return not torch.__future__.get_overwrite_module_params_on_conversion()
+            else:
+                return False
+
+        should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
+
+        for key, param in self._parameters.items():
+            if param is None:
+                continue
+            # Tensors stored in modules are graph leaves, and we don't want to
+            # track autograd history of `param_applied`, so we have to use
+            # `with torch.no_grad():`
+            with torch.no_grad():
+                param_applied = fn(param)
+            p_should_use_set_data = compute_should_use_set_data(param, param_applied)
+
+            # subclasses may have multiple child tensors so we need to use swap_tensors
+            p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
+
+            param_grad = param.grad
+            if p_should_use_swap_tensors:
+                try:
+                    if param_grad is not None:
+                        # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
+                        # Decrement use count of the gradient by setting to None
+                        param.grad = None
+                    param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
+                    torch.utils.swap_tensors(param, param_applied)
+                except Exception as e:
+                    if param_grad is not None:
+                        param.grad = param_grad
+                    raise RuntimeError(f"_apply(): Couldn't swap {type(self).__name__}.{key}") from e
+                out_param = param
+            elif p_should_use_set_data:
+                param.data = param_applied
+                out_param = param
+            else:
+                assert isinstance(param, torch.nn.Parameter)
+                assert param.is_leaf
+                out_param = torch.nn.Parameter(param_applied, param.requires_grad)
+                self._parameters[key] = out_param
+
+            if param_grad is not None:
+                with torch.no_grad():
+                    grad_applied = fn(param_grad)
+                g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied)
+                if p_should_use_swap_tensors:
+                    grad_applied.requires_grad_(param_grad.requires_grad)
+                    try:
+                        torch.utils.swap_tensors(param_grad, grad_applied)
+                    except Exception as e:
+                        raise RuntimeError(f"_apply(): Couldn't swap {type(self).__name__}.{key}.grad") from e
+                    out_param.grad = param_grad
+                elif g_should_use_set_data:
+                    assert out_param.grad is not None
+                    out_param.grad.data = grad_applied
+                else:
+                    assert param_grad.is_leaf
+                    out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad)
+
+        for key, buf in self._buffers.items():
+            if buf is not None:
+                self._buffers[key] = fn(buf)
+
+        return self
+
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        if dtype is not None:
+            if not (dtype.is_floating_point or dtype.is_complex):
+                raise TypeError('nn.Module.to only accepts floating point or complex '
+                                f'dtypes, but got desired dtype={dtype}')
+            if dtype.is_complex:
+                warnings.warn(
+                    "Complex modules are a new feature under active development whose design may change, "
+                    "and some modules might not work as expected when using complex tensors as parameters or buffers.")
+
+        def convert(t):
+            try:
+                if convert_to_format is not None and t.dim() in (4, 5):
+                    return t.to(
+                        device,
+                        dtype if t.is_floating_point() or t.is_complex() else None,
+                        non_blocking,
+                        memory_format=convert_to_format,
+                    )
+                return t.to(
+                    device,
+                    dtype if t.is_floating_point() or t.is_complex() else None,
+                    non_blocking,
+                )
+            except NotImplementedError as e:
+                if str(e) == "Cannot copy out of meta tensor; no data!":
+                    raise NotImplementedError(
+                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
+                        f"when moving module from meta to a different device."
+                    ) from None
+                else:
+                    raise
+
+        return self._apply(convert)
diff --git a/test/infinicore/infinicore_nn_test.py b/test/infinicore/infinicore_nn_test.py
new file mode 100644
index 000000000..e21dce723
--- /dev/null
+++ b/test/infinicore/infinicore_nn_test.py
@@ -0,0 +1,146 @@
+import safetensors.torch
+import torch
+import torch.nn as nn
+
+# ============================================================
+# 0. Import the infinicore package and set up a temporary
+#    safetensors path for the test
+# ============================================================
+import sys
+import os
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
+
+save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
+os.makedirs(save_dir, exist_ok=True)
+save_path = os.path.join(save_dir, "torch_convnet_with_param.safetensors")
+
+# ============================================================
+# 1. Define and save the model with PyTorch
+# ============================================================
+print("===== Starting CPU consistency test =====")
+
+class TorchConvNet(nn.Module):
+    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
+        super().__init__()
+        # Backbone layers
+        self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(hidden_ch)
+        self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(hidden_ch)
+        self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
+        self.relu = nn.ReLU()
+
+        # Custom Parameter
+        self.scale = nn.Parameter(torch.ones(1) * 0.5)
+
+        # Register a buffer
+        self.register_buffer("offset", torch.tensor(0.1))
+
+    def forward(self, x):
+        x = self.relu(self.bn1(self.conv1(x)))
+        x = self.relu(self.bn2(self.conv2(x)))
+        x = self.conv3(x)
+        # Apply the custom parameter and buffer
+        x = x * self.scale + self.offset
+        return x
+
+
+# ===== Save the Torch model =====
+torch_model = TorchConvNet()
+torch_state_dict = torch_model.state_dict()
+safetensors.torch.save_file(torch_state_dict, save_path)
+
+# ============================================================
+# 2. Load and run inference with torch
+# ============================================================
+
+torch_model_infer = TorchConvNet()
+torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
+torch_model_infer.eval()
+
+input_tensor = torch.rand(1, 3, 8, 8)
+torch_model_out = torch_model_infer(input_tensor)
+
+# ============================================================
+# 3. Load and run inference with infinicore.nn.Module
+# ============================================================
+
+# This imports InfiniCore's top-level `nn` package from the path added in
+# section 0; the local name `nn` still refers to torch.nn.
+from nn import Module
+
+class InfiniCoreConvNet(Module):
+    def __init__(self, in_ch=3, hidden_ch=8, out_ch=3):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(hidden_ch)
+        self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(hidden_ch)
+        self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
+        self.relu = nn.ReLU()
+
+        # Keep the custom parameter and buffer consistent with the Torch model
+        self.scale = nn.Parameter(torch.ones(1) * 0.5)
+        self.register_buffer("offset", torch.tensor(0.1))
+
+    def forward(self, x):
+        x = self.relu(self.bn1(self.conv1(x)))
+        x = self.relu(self.bn2(self.conv2(x)))
+        x = self.conv3(x)
+        x = x * self.scale + self.offset
+        return x
+
+# ===== Load the safetensors file with InfiniCoreConvNet and run inference =====
+infinicore_model_infer = InfiniCoreConvNet()
+infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
+infinicore_model_infer.eval()
+
+# InfiniCoreModule does not define __call__, so invoke forward() directly.
+infinicore_model_out = infinicore_model_infer.forward(input_tensor)
+
+# ============================================================
+# 4. Compare results
+# ============================================================
+
+diff_cpu = (infinicore_model_out - torch_model_out).abs().max().item()
+print(f"Max difference between InfiniCoreModule and Torch (CPU): {diff_cpu:.6e}")
+
+if diff_cpu < 1e-9:
+    print("InfiniCore and Torch outputs match exactly on CPU.")
+else:
+    print("Outputs differ on CPU.")
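+
+# Structural sanity check (sketch): both implementations should expose
+# identical state_dict keys, since InfiniCoreModule mirrors torch.nn.Module's
+# parameter/buffer naming scheme.
+assert set(infinicore_model_infer.state_dict().keys()) == set(torch_state_dict.keys())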
+
+# ============================================================
+# 5. GPU consistency test (optional)
+# ============================================================
+
+if torch.cuda.is_available():
+    print("\n===== Starting GPU consistency test =====")
+
+    # Move both models and the input to the GPU
+    torch_model_infer_gpu = TorchConvNet().to("cuda")
+    torch_model_infer_gpu.load_state_dict(safetensors.torch.load_file(save_path))
+    torch_model_infer_gpu.eval()
+
+    infinicore_model_infer_gpu = InfiniCoreConvNet().to("cuda")
+    infinicore_model_infer_gpu.load_state_dict(safetensors.torch.load_file(save_path))
+    infinicore_model_infer_gpu.eval()
+
+    # Create the GPU input
+    input_gpu = input_tensor.to("cuda")
+
+    # Run a forward pass through each model
+    torch_out_gpu = torch_model_infer_gpu(input_gpu)
+    infinicore_out_gpu = infinicore_model_infer_gpu.forward(input_gpu)
+
+    # Compare results
+    diff_gpu = (infinicore_out_gpu - torch_out_gpu).abs().max().item()
+    print(f"Max difference between InfiniCoreModule and Torch (GPU): {diff_gpu:.6e}")
+
+    if diff_gpu < 1e-9:
+        print("InfiniCore and Torch outputs match exactly on GPU.")
+    else:
+        print("Outputs differ on GPU.")
+else:
+    print("\nNo GPU detected; skipping the GPU consistency test.")
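+
+# ============================================================
+# 6. Round-trip check (sketch): save the InfiniCore model's own
+#    state_dict back to safetensors and reload it. This uses only
+#    APIs exercised above; the file name is illustrative.
+# ============================================================
+
+roundtrip_path = os.path.join(save_dir, "infinicore_roundtrip.safetensors")
+safetensors.torch.save_file(infinicore_model_infer.state_dict(), roundtrip_path)
+
+reloaded_model = InfiniCoreConvNet()
+missing, unexpected = reloaded_model.load_state_dict(
+    safetensors.torch.load_file(roundtrip_path)
+)
+assert not missing and not unexpected, "state_dict round-trip should be lossless"
+print("state_dict round-trip check passed.")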