# ---------------------------------------------------------------------------
# File: python/infinicore/nn/__init__.py
# ---------------------------------------------------------------------------
from infinicore.nn import functional
from infinicore.nn.modules import *  # noqa: F403
from infinicore.nn.parameter import InfiniCoreParameter as Parameter

__all__ = ["functional", "Parameter"]

# ---------------------------------------------------------------------------
# File: python/infinicore/nn/modules/__init__.py
# ---------------------------------------------------------------------------
from .container import InfiniCoreModuleList as ModuleList
from .module import InfiniCoreModule as Module

__all__ = ["ModuleList", "Module"]

# ---------------------------------------------------------------------------
# File: python/infinicore/nn/modules/container.py
# ---------------------------------------------------------------------------
# ============================================
# Copyright (c) 2025, InfiniCore
#
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
# but based on InfiniCoreModule for inference purposes.

import operator
from collections import OrderedDict
from itertools import chain
from typing import Iterator, List, Optional, Sequence, TypeVar, Union

from .module import InfiniCoreModule as Module

# Type variable for list elements; every element must be an InfiniCoreModule.
ModuleType = TypeVar("ModuleType", bound=Module)


class InfiniCoreModuleList(Module):
    r"""Hold submodules in a list.

    An InfiniCoreModuleList can be indexed like a regular Python list, and the
    modules it contains are registered in ``_modules`` under the string keys
    ``"0"``, ``"1"``, ... so they are visible to all InfiniCoreModule methods
    (``named_modules``, ``state_dict``, ...).

    Only ``InfiniCoreModule`` instances (or ``None``) can be stored; anything
    else is rejected with ``TypeError``.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        >>> class MyModel(InfiniCoreModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         # ``MyLayer`` stands for any InfiniCoreModule subclass.
        ...         self.layers = InfiniCoreModuleList(
        ...             [MyLayer() for _ in range(10)]
        ...         )
        ...
        ...     def forward(self, x):
        ...         # The list can be iterated or indexed with ints.
        ...         for i, layer in enumerate(self.layers):
        ...             x = self.layers[i // 2](x) + layer(x)
        ...         return x
    """

    def __init__(self, modules: Optional[Sequence[ModuleType]] = None):
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Return ``idx`` as a non-negative string key of ``_modules``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[ModuleType, "InfiniCoreModuleList"]:
        """Return the module at ``idx``, or a new list of the same class for a slice."""
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: ModuleType) -> None:
        """Replace the module at ``idx``; validation is delegated to ``add_module``."""
        self.add_module(self._get_abs_string_index(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        """Delete the module(s) at ``idx`` and renumber the remaining keys."""
        if isinstance(idx, slice):
            for position in range(len(self._modules))[idx]:
                self._modules.pop(str(position), None)
        else:
            self._modules.pop(self._get_abs_string_index(idx), None)

        # Keys must stay contiguous ("0".."n-1"), so rebuild the mapping from
        # the surviving modules in their current order.
        if self._modules:
            self._modules = OrderedDict(
                (str(i), module) for i, module in enumerate(self._modules.values())
            )

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[ModuleType]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
        return self.extend(modules)

    def __add__(
        self, other: Union[Sequence[ModuleType], "InfiniCoreModuleList"]
    ) -> "InfiniCoreModuleList":
        r"""Return a new InfiniCoreModuleList by concatenating with another iterable.

        Args:
            other (iterable): iterable of modules to concatenate

        Raises:
            TypeError: if ``other`` is not a list, tuple or InfiniCoreModuleList.
        """
        if not isinstance(other, (list, tuple, InfiniCoreModuleList)):
            raise TypeError(
                f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
                f"got {type(other).__name__}"
            )

        combined = InfiniCoreModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def append(self, module: ModuleType) -> "InfiniCoreModuleList":
        r"""Append a given module to the end of the list.

        Args:
            module (InfiniCoreModule): module to append

        Returns:
            InfiniCoreModuleList: self
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append

        Returns:
            InfiniCoreModuleList: self

        Raises:
            TypeError: if ``modules`` is not iterable.
        """
        if not isinstance(modules, (list, tuple)):
            try:
                # Materialize first so that e.g. ``lst.extend(lst)`` terminates.
                modules = list(modules)
            except TypeError:
                raise TypeError(
                    f"InfiniCoreModuleList.extend should be called with an "
                    f"iterable, but got {type(modules).__name__}"
                ) from None

        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    def insert(self, index: int, module: ModuleType) -> None:
        r"""Insert a given module before a given index in the list.

        The index follows ``list.insert`` semantics: negative values count from
        the end and out-of-range values are clamped to the ends of the list.

        Args:
            index (int): index to insert.
            module (InfiniCoreModule): module to insert

        Raises:
            TypeError: if ``module`` is neither ``None`` nor an InfiniCoreModule.
        """
        # Validate before touching ``_modules`` so a rejected insert cannot
        # leave the list half-shifted.
        if module is not None and not isinstance(module, Module):
            raise TypeError(f"{module} is not a Module subclass")

        index = operator.index(index)
        size = len(self)
        if index < 0:
            index = max(index + size, 0)
        elif index > size:
            index = size

        # Shift the tail one slot to the right; key ``str(size)`` is new and is
        # appended last, so the mapping stays ordered by position.
        for i in range(size, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def pop(self, idx: int = -1) -> ModuleType:
        r"""Remove and return a module at the given index.

        Args:
            idx (int): index of the module to pop. Default: -1 (last module)

        Returns:
            Module: the module that was removed

        Raises:
            IndexError: if ``idx`` is out of range (including an empty list).
        """
        idx_str = self._get_abs_string_index(idx)
        module = self._modules[idx_str]
        # Go through __delitem__ so the remaining keys are renumbered.
        del self[int(idx_str)]
        return module

    def __repr__(self) -> str:
        """Return a string representation of the ModuleList."""
        if len(self) == 0:
            return self.__class__.__name__ + "()"

        lines = [f"({i}): {repr(module)}" for i, module in enumerate(self)]
        main_str = self.__class__.__name__ + "(\n  "
        main_str += "\n  ".join(lines) + "\n)"
        return main_str

    def __dir__(self) -> List[str]:
        """Return a list of attribute names, excluding numeric keys."""
        # Numeric names would only clutter dir() output.
        return [key for key in super().__dir__() if not key.isdigit()]


# ---------------------------------------------------------------------------
# File: python/infinicore/nn/modules/module.py
# ---------------------------------------------------------------------------
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Module`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework, custom
# parameter/buffer registration mechanisms, and simplified state_dict handling.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.

import itertools
import warnings
from collections import OrderedDict, namedtuple
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import infinicore  # noqa: F401  (kept: other code may rely on this import side effect)

from ...tensor import Tensor
from ..parameter import InfiniCoreParameter as Parameter

_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
T = TypeVar("T", bound="InfiniCoreModule")


class _IncompatibleKeys(
    namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])
):
    """Result of ``load_state_dict``: the missing and the unexpected keys."""

    def __repr__(self):
        # An empty string is printed when every key matched.
        if not self.missing_keys and not self.unexpected_keys:
            return ""
        return super().__repr__()

    __str__ = __repr__


class InfiniCoreModule:
    r"""Base class for InfiniCore neural network modules.

    Your models should also subclass this class.
    Modules can also contain other Modules, allowing to nest them in a tree structure.

    Parameters, buffers and child modules are tracked in the ``_parameters``,
    ``_buffers`` and ``_modules`` registries; attribute assignment routes
    values into the matching registry (see ``__setattr__``).
    """

    _version: int = 1
    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Tensor]]
    _non_persistent_buffers_set: Set[str]
    _modules: Dict[str, Optional["InfiniCoreModule"]]

    def __init__(self):
        # Use object.__setattr__ directly: the custom __setattr__ below needs
        # these registries to exist already.
        super().__setattr__("_parameters", OrderedDict())
        super().__setattr__("_buffers", OrderedDict())
        super().__setattr__("_non_persistent_buffers_set", set())
        super().__setattr__("_modules", OrderedDict())

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` from parameters, then buffers, then child modules.

        Only called when normal attribute lookup fails.
        """
        for registry in ("_parameters", "_buffers", "_modules"):
            if registry in self.__dict__:
                members = self.__dict__[registry]
                if name in members:
                    return members[name]

        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def __setattr__(self, name: str, value: Union[Tensor, "InfiniCoreModule"]) -> None:
        """Assign ``value`` to ``name``, registering parameters/modules/buffers.

        * A ``Parameter`` is registered as a parameter.
        * A ``Tensor`` assigned to an existing parameter name replaces it.
        * An ``InfiniCoreModule`` is registered as a child module.
        * A value assigned to an existing buffer name replaces the buffer.
        * Anything else becomes a plain instance attribute.

        Raises:
            AttributeError: if called before ``InfiniCoreModule.__init__``.
            TypeError: if a non-Tensor is assigned to an existing parameter
                name, a non-module to an existing child-module name, or a
                non-Tensor/non-None value to an existing buffer name.
        """

        def remove_from(*dicts_or_sets) -> None:
            # Drop any previous registration of ``name`` in other registries.
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

        params = self.__dict__.get("_parameters")
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call"
            )

        if isinstance(value, Parameter):
            remove_from(
                self.__dict__,
                self._buffers,
                self._modules,
                self._non_persistent_buffers_set,
            )
            self.register_parameter(name, value)
        elif name in params:
            # NOTE(review): the message mentions None, but None is rejected by
            # this check; kept as-is to preserve existing behavior.
            if not isinstance(value, Tensor):
                raise TypeError(
                    f"cannot assign 'value' as parameter '{name}' (infinicore.nn.Parameter, Parameter or None expected)"
                )
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get("_modules")
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call"
                )

            if isinstance(value, InfiniCoreModule):
                remove_from(
                    self.__dict__,
                    self._parameters,
                    self._buffers,
                    self._non_persistent_buffers_set,
                )
                modules[name] = value
            elif name in modules:
                # An existing child module may only be replaced by a module.
                raise TypeError(
                    f"cannot assign 'value' as child module '{name}' (infinicore.nn.Module or None expected)"
                )
            else:
                buffers = self.__dict__.get("_buffers")
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, Tensor):
                        raise TypeError(
                            f"cannot assign 'value' as buffer '{name}' "
                            "(torch.Tensor or None expected)"
                        )
                    buffers[name] = value
                else:
                    super().__setattr__(name, value)

    def __call__(self, *args, **kwargs):
        """Invoke ``self.forward`` with the given arguments."""
        return self.forward(*args, **kwargs)

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        r"""Add a buffer to the module.

        This is typically used to register a tensor that should not be
        considered a model parameter. Buffers are persistent by default and
        are saved alongside parameters. Setting :attr:`persistent` to ``False``
        excludes the buffer from this module's :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``,
                the buffer is **not** included in the module's :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Raises:
            AttributeError: if called before ``InfiniCoreModule.__init__``.
            TypeError: if ``name`` is not a string or ``tensor`` is neither
                ``None`` nor a ``Tensor``.
            KeyError: if ``name`` is empty, contains ``"."`` or clashes with an
                existing non-buffer attribute.
        """
        if "_buffers" not in self.__dict__:
            raise AttributeError("cannot assign buffer before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(
                f"buffer name should be a string. Got {type(name).__name__}"
            )
        if "." in name:
            raise KeyError('buffer name can\'t contain "."')
        if name == "":
            raise KeyError('buffer name can\'t be empty string ""')
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        if tensor is not None and not isinstance(tensor, Tensor):
            raise TypeError(
                f"cannot assign '{type(tensor).__name__}' object to buffer '{name}' "
                "(infinicore.Tensor or None required)"
            )

        self._buffers[name] = tensor
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)

    def add_module(self, name: str, module: Optional["InfiniCoreModule"]) -> None:
        r"""Add a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (str): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module or None): child module to be added to the module.
                ``None`` entries are skipped by :meth:`named_children`,
                :meth:`named_modules` and :meth:`state_dict`.

        Raises:
            TypeError: if ``name`` is not a string or ``module`` is neither
                ``None`` nor an ``InfiniCoreModule``.
            KeyError: if ``name`` is empty, contains ``"."`` or clashes with an
                existing non-module attribute.
        """
        if not isinstance(name, str):
            raise TypeError(f"module name should be a string. Got {name}")
        elif "." in name:
            raise KeyError(f'module name can\'t contain ".", got: {name}')
        elif name == "":
            raise KeyError('module name can\'t be empty string ""')
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")

        if module is not None and not isinstance(module, InfiniCoreModule):
            raise TypeError(f"{module} is not a Module subclass")

        self._modules[name] = module

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Add a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name (str): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter, Tensor or None): parameter to be added to the
                module. If ``None``, the parameter is **not** included in the
                module's :attr:`state_dict`.

        Raises:
            AttributeError: if called before ``InfiniCoreModule.__init__``.
            TypeError: if ``name`` is not a string or ``param`` has the wrong type.
            KeyError: if ``name`` is empty, contains ``"."`` or clashes with an
                existing non-parameter attribute.
        """
        if "_parameters" not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call"
            )
        elif not isinstance(name, str):
            raise TypeError("parameter name should be a string.")
        elif "." in name:
            raise KeyError('parameter name can\'t contain "."')
        elif name == "":
            raise KeyError('parameter name can\'t be empty string ""')
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        if param is None:
            # None is a valid placeholder value for a parameter.
            self._parameters[name] = None
            # Drop a previously mirrored instance attribute so that attribute
            # access does not keep returning the old parameter.
            self.__dict__.pop(name, None)
        else:
            if not isinstance(param, (Parameter, Tensor)):
                raise TypeError(
                    f"cannot assign 'param' object to parameter '{name}' "
                    "(infinicore.nn.Parameter, Parameter or None required)"
                )

            self._parameters[name] = param
            # Mirror the parameter as a plain instance attribute so lookups
            # do not have to go through __getattr__.
            super().__setattr__(name, param)

    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict.

        Implement this and a corresponding :func:`set_extra_state` for your module
        if you need to store extra state. This function is called when building the
        module's `state_dict()` only if a subclass overrides it.

        Returns:
            object: Any extra state to store in the module's state_dict

        Raises:
            RuntimeError: always, in this base implementation.
        """
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        r"""Save this module's own state (not its descendants') into ``destination``.

        Called on every submodule by :meth:`state_dict`. Subclasses may override
        it to customize what is saved.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
            keep_vars (bool): accepted for API compatibility; tensors are
                stored by reference regardless of its value.
        """
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf

        # Extra state is saved only when a subclass overrides get_extra_state.
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(self.__class__, "get_extra_state", InfiniCoreModule.get_extra_state)
            is not InfiniCoreModule.get_extra_state
        ):
            destination[extra_state_key] = self.get_extra_state()

    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
    # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
    T_destination = TypeVar("T_destination", bound=Dict[str, Any])

    @overload
    def state_dict(
        self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...
    ) -> T_destination: ...

    @overload
    def state_dict(
        self, *, prefix: str = ..., keep_vars: bool = ...
    ) -> Dict[str, Any]: ...

    # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
    # Also remove the logic for arg parsing together.
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        r"""Return a dictionary containing references to the whole state of the module.

        Both parameters and persistent buffers are included. Keys are the
        corresponding parameter and buffer names. Parameters and buffers set to
        ``None`` are not included.

        .. note::
            The returned object is a shallow copy. It contains references
            to the module's parameters and buffers.

        .. warning::
            ``destination``, ``prefix`` and ``keep_vars`` are also accepted
            positionally, in that order, but this is deprecated and emits a
            ``FutureWarning``.

        Args:
            destination (dict, optional): If provided, the state of module will
                be updated into the dict and the same object is returned.
                Otherwise, an ``OrderedDict`` will be created and returned.
                Default: ``None``.
            prefix (str, optional): a prefix added to parameter and buffer
                names to compose the keys in state_dict. Default: ``''``.
            keep_vars (bool, optional): accepted for API compatibility; it has
                no effect in this implementation. Default: ``False``.

        Returns:
            dict:
                a dictionary containing a whole state of the module
        """
        # TODO: Remove `args` and the parsing logic when BC allows.
        if len(args) > 0:
            # DeprecationWarning is ignored by default
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. ",
                FutureWarning,
                stacklevel=2,
            )
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == "":
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            # The metadata key is the module path without the trailing ".".
            destination._metadata[prefix[:-1]] = local_metadata

        self._save_to_state_dict(destination, prefix, keep_vars)
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(
                    destination=destination,
                    prefix=prefix + name + ".",
                    keep_vars=keep_vars,
                )
        return destination

    def set_extra_state(self, state: Any):
        """Restore extra state found in a ``state_dict``.

        Called from :func:`load_state_dict` only if a subclass overrides it.
        Implement this function and a corresponding :func:`get_extra_state` for
        your module if you need to store extra state within its `state_dict`.

        Args:
            state (dict): Extra state from the `state_dict`

        Raises:
            RuntimeError: always, in this base implementation.
        """
        raise RuntimeError(
            "Reached a code path in Module.set_extra_state() that should never be called. "
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module only.

        Descendants are not touched; :meth:`load_state_dict` calls this on
        every submodule. A tensor whose shape, dtype and device match the
        existing one is copied in place; otherwise the checkpoint tensor
        replaces the existing parameter/buffer.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in :meth:`load_state_dict`

        Raises:
            TypeError: if a matching checkpoint entry is not a ``Tensor``.
        """
        persistent_buffers = {
            k: v
            for k, v in self._buffers.items()
            if k not in self._non_persistent_buffers_set
        }
        local_name_params = itertools.chain(
            self._parameters.items(), persistent_buffers.items()
        )
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]

                # input_param must be of type infinicore.Tensor
                if not isinstance(input_param, Tensor):
                    raise TypeError(
                        f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
                    )

                if (
                    (param.shape == input_param.shape)
                    and (param.dtype == input_param.dtype)
                    and (param.device == input_param.device)
                ):
                    param.copy_(input_param)
                else:
                    # Mismatched tensors are swapped in rather than copied.
                    print(f"param '{name}' don't match input_param '{key}'")
                    setattr(self, name, input_param)

            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(self.__class__, "set_extra_state", InfiniCoreModule.set_extra_state)
            is not InfiniCoreModule.set_extra_state
        ):
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                # The extra-state key was already handled above; without this
                # exclusion it would always be reported as unexpected.
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix) :].split(".", 1)
                    # A dotted remainder must belong to a child module.
                    if len(input_name) > 1:
                        if input_name[0] not in self._modules:
                            unexpected_keys.append(key)
                    elif input_name[0] not in local_state:
                        unexpected_keys.append(key)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module
        and its descendants.

        If :attr:`strict` is ``True``, the keys of :attr:`state_dict` must
        exactly match the keys returned by this module's :meth:`state_dict`.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`state_dict` function. Default: ``True``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Raises:
            TypeError: if ``state_dict`` is not a mapping.
            RuntimeError: if ``strict`` is ``True`` and there are missing or
                unexpected keys, or if any error message was collected.

        Note:
            A parameter or buffer registered as ``None`` is not loadable; a
            matching key in :attr:`state_dict` is reported as unexpected.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(
                f"Expected state_dict to be dict-like, got {type(state_dict)}."
            )

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            # Keys are always collected; ``strict`` only decides below whether
            # they turn into an error.
            module._load_from_state_dict(
                local_state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + "."
                    child_state_dict = {
                        k: v
                        for k, v in local_state_dict.items()
                        if k.startswith(child_prefix)
                    }
                    load(child, child_state_dict, child_prefix)  # noqa: F821

        load(self, state_dict)
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0,
                    "Unexpected key(s) in state_dict: {}. ".format(
                        ", ".join(f'"{k}"' for k in unexpected_keys)
                    ),
                )
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0,
                    "Missing key(s) in state_dict: {}. ".format(
                        ", ".join(f'"{k}"' for k in missing_keys)
                    ),
                )

        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    self.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def parameters(self, recurse: bool = True) -> Iterator["Parameter"]:
        r"""Return an iterator over module parameters.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter
        """
        for _name, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, "Parameter"]]:
        r"""Return an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            (str, Parameter): Tuple containing the name and parameter
        """
        yield from self._named_members(
            lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
        )

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Return an iterator over module buffers.

        Non-persistent buffers are skipped (see :meth:`named_buffers`).

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            Tensor: module buffer
        """
        for _name, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        r"""Return an iterator over module buffers, yielding both the
        name of the buffer as well as the buffer itself.

        ``None`` buffers, non-persistent buffers and buffers already yielded
        under another name are skipped.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            (str, Tensor): Tuple containing the name and buffer
        """

        def persistent_buffers(module):
            skipped = module._non_persistent_buffers_set
            return ((k, v) for k, v in module._buffers.items() if k not in skipped)

        yield from self._named_members(
            persistent_buffers, prefix=prefix, recurse=recurse
        )

    def _named_members(self, get_members_fn, prefix="", recurse=True):
        r"""Yield ``(qualified_name, member)`` pairs, skipping ``None`` and duplicates.

        Args:
            get_members_fn (callable): maps a module to an iterable of
                ``(name, member)`` pairs.
            prefix (str): prefix to prepend to all member names.
            recurse (bool): whether to also visit all submodules.
        """
        memo = set()
        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
        for module_prefix, module in modules:
            for k, v in get_members_fn(module):
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = module_prefix + ("." if module_prefix else "") + k
                yield (name, v)

    def modules(self) -> Iterator["InfiniCoreModule"]:
        r"""Return an iterator over all modules in the network.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once.
        """
        for _name, module in self.named_modules():
            yield module

    def named_modules(
        self,
        memo: Optional[Set["InfiniCoreModule"]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ):
        r"""Return an iterator over all modules in the network, yielding
        both the name of the module as well as the module itself.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (str, Module): Tuple of name and module

        Note:
            With ``remove_duplicate=True`` a module instance reachable through
            several paths is yielded only once, under the first name found.
        """
        if memo is None:
            memo = set()
        if remove_duplicate:
            if self in memo:
                return
            memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ("." if prefix else "") + name
            # Children that are not InfiniCoreModule instances are skipped.
            if isinstance(module, InfiniCoreModule):
                yield from module.named_modules(
                    memo, submodule_prefix, remove_duplicate
                )

    def children(self) -> Iterator["InfiniCoreModule"]:
        r"""Return an iterator over immediate children modules.

        Yields:
            Module: a child module
        """
        for _name, module in self.named_children():
            yield module

    def named_children(
        self,
    ) -> Iterator[Tuple[str, "InfiniCoreModule"]]:
        r"""Return an iterator over immediate children modules, yielding both
        the name of the module as well as the module itself.

        ``None`` children and repeated instances are skipped.

        Yields:
            (str, Module): Tuple containing a name and child module
        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def eval(self: T) -> T:
        r"""Set the module in evaluation mode.

        This implementation keeps no training/evaluation state, so the call
        only returns the module to allow chaining such as ``Model().eval()``.

        Returns:
            Module: self
        """
        return self

    # NOTE(review): KeyError is an unusual type for "unsupported operation";
    # NotImplementedError would be conventional. Kept for compatibility.
    def _apply(self, fn, recurse=True):
        """Unsupported; always raises ``KeyError``."""
        raise KeyError("not support")

    def to(self, *args, **kwargs):
        """Unsupported; always raises ``KeyError``."""
        raise KeyError("not support")


# ---------------------------------------------------------------------------
# File: python/infinicore/nn/parameter.py
# ---------------------------------------------------------------------------
# Copyright (c) 2025, InfiniCore
#
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
# implementation, which is licensed under the BSD 3-Clause License.
#
# The modifications include adaptations for the InfiniCore framework.
#
# Original PyTorch source:
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
#
# Referencing PyTorch v2.4.0
#
# The use of this file is governed by the BSD 3-Clause License.


from ..tensor import Tensor


class InfiniCoreParameter(Tensor):
    r"""A kind of Tensor that is to be considered a module parameter."""

    def __init__(self, data=None):
        """Wrap the underlying handle of ``data``.

        Args:
            data (Tensor): tensor whose ``_underlying`` object is reused.

        Raises:
            ValueError: if ``data`` is not a ``Tensor`` (including the
                default ``None``).
        """
        if not isinstance(data, Tensor):
            raise ValueError("The `data` variable must be of type `infinicore.Tensor`.")
        super().__init__(data._underlying)

    def __repr__(self):
        return "Parameter containing:\n" + super().__repr__()

    def __deepcopy__(self, memo):
        """Deep copying is not supported; always raises ``ValueError``."""
        raise ValueError("not supported!")

    def __reduce_ex__(self, proto):
        """Pickling is not supported; always raises ``ValueError``."""
        raise ValueError("not supported!")


# ---------------------------------------------------------------------------
# File: test/infinicore/nn/Module.py
# ---------------------------------------------------------------------------
# ============================================================
# 0.
infinicore 包导入,配置测试用 safetensors 临时存储路径 +# ============================================================ +import os +import sys + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore")) +) + +save_dir = os.path.join(os.path.dirname(__file__), "../../tmp") +os.makedirs(save_dir, exist_ok=True) +save_path = os.path.join(save_dir, "torch_convnet_with_param.safetensors") + + +import infinicore # noqa: E402 +from infinicore.nn import Module # noqa: E402 + + +# ============================================================ +# 1. 定义模型 +# ============================================================ +device_str = "cuda" + + +class InfiniCoreNet(Module): + def __init__(self): + super().__init__() + self.a = infinicore.nn.Parameter( + infinicore.empty( + (1, 2, 3), + dtype=infinicore.float32, + device=infinicore.device(device_str), + ) + ) + self.b = infinicore.nn.Parameter( + infinicore.empty( + (1, 2, 3), + dtype=infinicore.float32, + device=infinicore.device(device_str), + ) + ) + + def forward(self): + return infinicore.add(self.a, self.b) + +infinicore_model_infer = InfiniCoreNet() +# ============================================================ +# 2. 加载权重 +# ============================================================ + +params_dict = { + "a": infinicore.empty( + (1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0) + ), + "b": infinicore.empty( + (1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0) + ), +} +infinicore_model_infer.load_state_dict(params_dict) + +# ============================================================ +# 3. 计算 +# ============================================================ +infinicore_model_out = infinicore_model_infer() +ref_out = infinicore.add(params_dict["a"], params_dict["b"]) + + +# ============================================================ +# 4. 
对比结果 +# ============================================================ +print("InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 ") +infinicore_model_out.debug() +ref_out.debug() + + +# ============================================================ +# 5. to测试,buffer测试 +# ============================================================ +# 等待添加 diff --git a/test/infinicore/nn/ModuleList.py b/test/infinicore/nn/ModuleList.py new file mode 100644 index 000000000..b7544688a --- /dev/null +++ b/test/infinicore/nn/ModuleList.py @@ -0,0 +1,323 @@ +import os + +# ============================================================ +# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径 +# ============================================================ +import sys + +import safetensors +import safetensors.torch +import torch +import torch.nn as nn + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore")) +) + +# 使用临时目录,如果不存在则自动创建 +save_dir = os.path.join(os.path.dirname(__file__), "../../tmp") +os.makedirs(save_dir, exist_ok=True) +save_path = os.path.join(save_dir, "torch_modulelist_with_param.safetensors") + + +def test(): + # ============================================================ + # 1. 
使用 PyTorch 定义并保存模型(使用 torch.nn.ModuleList) + # ============================================================ + class TorchModuleListNet(nn.Module): + def __init__(self, in_ch=3, hidden_ch=8, out_ch=3): + super().__init__() + # 使用 torch.nn.ModuleList + self.layers = nn.ModuleList( + [ + nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_ch), + nn.ReLU(), + nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_ch), + nn.ReLU(), + nn.Conv2d(hidden_ch, out_ch, kernel_size=1), + ] + ) + + # 自定义 Parameter + self.scale = nn.Parameter(torch.ones(1) * 0.5) + self.register_buffer("offset", torch.tensor(0.1)) + + def forward(self, x): + # 遍历 ModuleList 中的所有层 + for layer in self.layers: + x = layer(x) + # 应用自定义参数和 buffer + x = x * self.scale + self.offset + return x + + # ===== 保存 Torch 模型 ===== + torch_model = TorchModuleListNet() + torch_state_dict = torch_model.state_dict() + safetensors.torch.save_file(torch_state_dict, save_path) + print("✓ PyTorch 模型已保存") + + # ============================================================ + # 2. 使用 torch 方式加载并推理 + # ============================================================ + + torch_model_infer = TorchModuleListNet() + torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path)) + torch_model_infer.eval() + + input = torch.rand(1, 3, 8, 8) + torch_model_out = torch_model_infer(input) + print("✓ Torch 输出:", torch_model_out.detach().numpy().mean()) + + # ============================================================ + # 3. 
使用 ModuleList 加载并推理 + # ============================================================ + + from nn.modules import Module, ModuleList + + class InfiniCoreModuleListNet(Module): + def __init__(self, in_ch=3, hidden_ch=8, out_ch=3): + super().__init__() + # 使用 ModuleList + self.layers = ModuleList( + [ + nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_ch), + nn.ReLU(), + nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(hidden_ch), + nn.ReLU(), + nn.Conv2d(hidden_ch, out_ch, kernel_size=1), + ] + ) + + # 保持与 Torch 模型一致的自定义参数和 buffer + self.scale = nn.Parameter(torch.ones(1) * 0.5) + self.register_buffer("offset", torch.tensor(0.1)) + + def forward(self, x): + # 遍历 ModuleList 中的所有层 + for layer in self.layers: + x = layer(x) + x = x * self.scale + self.offset + return x + + # ===== 使用 ModuleListNet 读取 safetensors 并推理 ===== + infinicore_model_infer = InfiniCoreModuleListNet() + infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path)) + infinicore_model_infer.eval() + + infinicore_model_out = infinicore_model_infer.forward(input) + print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean()) + + # ============================================================ + # 4. 对比结果 + # ============================================================ + + diff = (infinicore_model_out - torch_model_out).abs().max().item() + print(f"✓ ModuleList 与 Torch 最大误差: {diff:.8f}") + if diff < 1e-9: + print("✓ ModuleList 与 Torch 精度一致.") + else: + print("✗ ModuleList 与 Torch 精度存在差异.") + + # ============================================================ + # 5. 
测试 ModuleList 的基本功能 + # ============================================================ + + print("\n=== 测试 ModuleList 基本功能 ===") + + # 测试 1: 创建和访问 + module_list = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)]) + + print(f"✓ 创建 ModuleList,长度: {len(module_list)}") + print(f"✓ 访问第一个模块: {type(module_list[0]).__name__}") + print(f"✓ 访问第二个模块: {type(module_list[1]).__name__}") + + # 测试 2: append + module_list.append(nn.Softmax(dim=-1)) + print(f"✓ append 后长度: {len(module_list)}") + + # 测试 3: extend + module_list.extend([nn.Dropout(0.1), nn.Linear(5, 1)]) + print(f"✓ extend 后长度: {len(module_list)}") + + # 测试 4: 迭代 + print("✓ 迭代 ModuleList:") + for i, module in enumerate(module_list): + print(f" [{i}] {type(module).__name__}") + + # 测试 5: 索引访问 + print(f"✓ 索引访问 module_list[0]: {type(module_list[0]).__name__}") + + # 测试 6: state_dict + state_dict = module_list.state_dict() + print(f"✓ state_dict 键数量: {len(state_dict)}") + print(f"✓ state_dict 包含模块参数: {any('0.' in k for k in state_dict.keys())}") + + # 测试 7: 使用 ModuleList 的模型 + class TestNet(Module): + def __init__(self): + super().__init__() + self.layers = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + test_model = TestNet() + test_input = torch.randn(2, 10) + test_output = test_model.forward(test_input) + print(f"✓ TestNet 输入形状: {test_input.shape}, 输出形状: {test_output.shape}") + + # 测试 8: __add__ 方法 + ml1 = ModuleList([nn.Linear(10, 5), nn.ReLU()]) + ml2 = ModuleList([nn.Linear(5, 3), nn.Sigmoid()]) + ml3 = ml1 + ml2 + print(f"✓ __add__ 方法测试: {len(ml1)} + {len(ml2)} = {len(ml3)}") + assert len(ml3) == 4, "合并后的长度应该为 4" + + # 测试 9: pop 方法 + ml4 = ModuleList([nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)]) + popped = ml4.pop() + print( + f"✓ pop 方法测试: 弹出后长度 {len(ml4)}, 弹出模块类型 {type(popped).__name__}" + ) + assert len(ml4) == 2, "pop 后长度应该为 2" + assert isinstance(popped, nn.Linear), "弹出的应该是 Linear 模块" + + # 测试 10: 
__repr__ 方法 + ml5 = ModuleList([nn.Linear(10, 5), nn.ReLU()]) + repr_str = repr(ml5) + print(f"✓ __repr__ 方法测试: 输出包含类名和模块信息") + assert "ModuleList" in repr_str or "InfiniCoreModuleList" in repr_str, ( + "repr 应该包含类名" + ) + assert "Linear" in repr_str, "repr 应该包含模块信息" + print(repr_str) + + print("\n=== 所有测试通过! ===") + + # ============================================================ + # 6. 前向传播集成测试(参考 infinicore_nn_test.py) + # ============================================================ + + print("\n=== 前向传播集成测试 ===") + + # 使用 ModuleList 创建一个简单的模型 + class TorchModuleListModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList( + [nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)] + ) + self.scale = nn.Parameter(torch.ones(1) * 0.5) + self.register_buffer("offset", torch.tensor(0.1)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = x * self.scale + self.offset + return x + + class InfiniCoreModuleListModel(Module): + def __init__(self): + super().__init__() + self.layers = ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)]) + self.scale = nn.Parameter(torch.ones(1) * 0.5) + self.register_buffer("offset", torch.tensor(0.1)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = x * self.scale + self.offset + return x + + # 创建模型 + torch_model_forward = TorchModuleListModel() + infinicore_model_forward = InfiniCoreModuleListModel() + + # 复制权重(确保初始权重一致) + infinicore_model_forward.load_state_dict( + torch_model_forward.state_dict(), strict=False + ) + + # 设置为评估模式 + torch_model_forward.eval() + infinicore_model_forward.eval() + + # 创建测试输入 + test_input = torch.randn(2, 10) + + # 前向传播 + with torch.no_grad(): + torch_output = torch_model_forward(test_input) + infinicore_output = infinicore_model_forward.forward(test_input) + + # 对比结果 + diff = (infinicore_output - torch_output).abs().max().item() + print(f"✓ 前向传播测试 - 输入形状: {test_input.shape}") + print( + f"✓ Torch 输出形状: {torch_output.shape}, 
均值: {torch_output.detach().numpy().mean():.8f}" + ) + print( + f"✓ InfiniCore 输出形状: {infinicore_output.shape}, 均值: {infinicore_output.detach().numpy().mean():.8f}" + ) + print(f"✓ 最大误差: {diff:.8f}") + + if diff < 1e-9: + print("✓ 前向传播集成测试通过:ModuleList 与 Torch ModuleList 结果一致!") + else: + print("✗ 前向传播集成测试失败:存在差异") + + # ============================================================ + # 7. 混合模块兼容性测试(PyTorch + InfiniCore 模块混合使用) + # ============================================================ + + print("\n=== 混合模块兼容性测试 ===") + + # 创建一个自定义的 InfiniCore 模块 + class CustomLinear(Module): + def __init__(self, in_features, out_features): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + self.bias = nn.Parameter(torch.randn(out_features)) + + def forward(self, x): + return x @ self.weight.t() + self.bias + + # 创建混合 ModuleList(包含 PyTorch 模块和 InfiniCore 模块) + mixed_list = ModuleList( + [ + nn.Linear(10, 5), # PyTorch 模块 + CustomLinear(5, 3), # 自定义 InfiniCore 模块 + nn.ReLU(), # PyTorch 模块 + ] + ) + + print(f"✓ 创建混合 ModuleList,长度: {len(mixed_list)}") + print(f"✓ 模块类型: {[type(m).__name__ for m in mixed_list]}") + + # 测试参数注册 + param_count = sum(1 for _ in mixed_list.parameters()) + print(f"✓ 参数数量: {param_count}") + assert param_count == 4, ( + f"参数数量应该为 4 (Linear: weight+bias, CustomLinear: weight+bias), 实际为 {param_count}" + ) + + # 测试 state_dict + mixed_state_dict = mixed_list.state_dict() + print(f"✓ state_dict 键数量: {len(mixed_state_dict)}") + assert len(mixed_state_dict) >= 4, "state_dict 应该包含至少 4 个参数" + + # 测试前向传播 + test_input_mixed = torch.randn(2, 10) + with torch.no_grad(): + x = test_input_mixed + for module in mixed_list: + x = module.forward(x) + print(f"✓ 混合模块前向传播成功,输出形状: {x.shape}") + + print("✓ 混合模块兼容性测试通过!") diff --git a/test/infinicore/nn/Parameter.py b/test/infinicore/nn/Parameter.py new file mode 100644 index 000000000..1a208dd86 --- /dev/null +++ b/test/infinicore/nn/Parameter.py @@ -0,0 +1,148 @@ +# 
============================================================ +# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径 +# ============================================================ +import os +import sys + + +import torch +import torch.nn as nn + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python/infinicore")) +) + +save_dir = os.path.join(os.path.dirname(__file__), "../../tmp") +os.makedirs(save_dir, exist_ok=True) +save_path = os.path.join(save_dir, "infinicore_parameter_test.safetensors") + + +import infinicore # noqa: E402 +from infinicore.nn import Module, Parameter # noqa: E402 + +device_str = "cuda" + + +class InfiniCoreParameterNet(Module): + def __init__(self): + super().__init__() + self.a = infinicore.nn.Parameter( + infinicore.empty( + (1, 2, 3), dtype=infinicore.float32, device=infinicore.device("cpu", 0) + ) + ) + + def forward(self, x): + return infinicore.add(self.a, x) + + +infinicore_model_infer = InfiniCoreParameterNet() +# ============================================================ +# 2. 加载权重 +# ============================================================ +params_dict = { + "a": infinicore.empty( + (1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0) + ) +} +infinicore_model_infer.load_state_dict(params_dict) + + +# ============================================================ +# 3. 计算 +# ============================================================ +x = infinicore.empty( + (1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0) +) + +infinicore_model_out = infinicore_model_infer(x) +ref_out = infinicore.add(params_dict["a"], x) + +# ============================================================ +# 4. 对比结果 +# ============================================================ +print("InfiniCoreModule 与 Torch (CPU) 最大误差: 手动查看 ") +infinicore_model_out.debug() +ref_out.debug() + + +# ============================================================ +# 5. 
测试 Parameter 的基本功能 +# ============================================================ + +print("\n=== 测试 Parameter 基本功能 ===") + +# 测试 1: 创建 Parameter +param1 = infinicore.nn.Parameter( + infinicore.empty( + (1, 2, 3), dtype=infinicore.float32, device=infinicore.device(device_str, 0) + ) +) +print(f"✓ 创建 Parameter,形状: {param1.shape}") +# 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名) + +assert isinstance(param1, infinicore.nn.Parameter), "应该是 Parameter 类型" +assert isinstance(param1, infinicore.Tensor), "应该是 torch.Tensor 的子类" + + +# 测试 3: 自动注册到 Module +class TestModule(Module): + def __init__(self): + super().__init__() + self.weight = infinicore.nn.Parameter( + infinicore.empty( + (1, 2, 3), + dtype=infinicore.float32, + device=infinicore.device(device_str), + ) + ) + self.bias = infinicore.nn.Parameter( + infinicore.empty( + (1, 2, 3), + dtype=infinicore.float32, + device=infinicore.device(device_str), + ) + ) + + +test_module = TestModule() +param_count = sum(1 for _ in test_module.parameters()) +print(f"✓ 自动注册到 Module,参数数量: {param_count}") +assert param_count == 2, f"应该有 2 个参数,实际为 {param_count}" + +# 测试 4: 参数访问 +assert test_module.weight is not None, "weight 应该可以访问" +assert test_module.bias is not None, "bias 应该可以访问" +print("✓ 参数可以通过属性访问") + +# 测试 5: state_dict +state_dict = test_module.state_dict() +print(f"✓ state_dict 键数量: {len(state_dict)}") +assert "weight" in state_dict, "state_dict 应该包含 weight" +assert "bias" in state_dict, "state_dict 应该包含 bias" +print(f"✓ state_dict 键: {list(state_dict.keys())}") + +# 测试 6: __repr__ +repr_str = repr(param1) +print(f"✓ __repr__ 方法: 输出包含类名") +assert "Parameter" in repr_str or "InfiniCoreParameter" in repr_str, "repr 应该包含类名" +print(repr_str[:100] + "...") + + +# 测试 9: 从 None 创建 +# param_empty = Parameter(None) +# print(f"✓ 从 None 创建 Parameter,形状: {param_empty.shape}") +# assert param_empty.shape == torch.Size([0]), "从 None 创建应该是空张量" + + +# 测试 10: 深拷贝 +# import copy + +# param_copy = copy.deepcopy(param1) +# print(f"✓ 深拷贝 
Parameter,形状: {param_copy.shape}") +# assert param_copy.shape == param1.shape, "深拷贝后形状应该相同" +# assert not torch.equal(param_copy, param1) or id(param_copy) != id(param1), ( +# "深拷贝应该是新对象" +# ) + +print("\n=== 所有测试通过! ===")