|
1 | | -"""Container component for composing multiple components, such as Sequential.""" |
2 | | - |
3 | | -from collections import OrderedDict |
| 1 | +""" |
| 2 | +Container component for composing multiple components, such as Sequential |
| 3 | +and ComponentList. |
| 4 | +
|
| 5 | +This design draws inspiration from PyTorch’s modular |
| 6 | +container patterns, including `nn.Sequential` and `nn.ModuleList`. The |
| 7 | +`Container` component allows for grouping several components into one, enabling |
| 8 | +flexible and reusable model architectures. |
| 9 | +
|
| 10 | +Design Motivation: |
| 11 | +------------------- |
| 12 | +This implementation follows the same principles as PyTorch’s component-based |
| 13 | +design, encouraging modularity, reusability, and extensibility. The `Container` |
| 14 | +component provides an easy way to manage multiple layers or other components, |
| 15 | +while ensuring that their parameters are properly registered and updated during |
| 16 | +training. |
| 17 | +
|
| 18 | +Credits: |
| 19 | +--------- |
| 20 | +The design of this component takes inspiration from the PyTorch project |
| 21 | +(https://pytorch.org). PyTorch is an open-source deep learning framework, |
| 22 | +licensed under a BSD-style license. Although this code is not part of the |
| 23 | +official PyTorch library, it mirrors the same design principles. |
| 24 | +
|
| 25 | +For more details on PyTorch’s licensing, refer to: |
| 26 | +https://github.com/pytorch/pytorch/blob/main/LICENSE |
| 27 | +
|
| 28 | +Usage Example: |
| 29 | +-------------- |
| 30 | + class MyModule(nn.Module): |
| 31 | + def __init__(self): |
| 32 | + super().__init__() |
| 33 | +
|
| 34 | + self.model = nn.Sequential( |
| 35 | + nn.Conv2d(1,20,5), |
| 36 | + nn.ReLU(), |
| 37 | + nn.Conv2d(20,64,5), |
| 38 | + nn.ReLU() |
| 39 | + ) |
| 40 | + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) |
| 41 | +
|
| 42 | + def forward(self, x): |
| 43 | + # ModuleList can act as an iterable, or be indexed using ints |
| 44 | + for i, l in enumerate(self.linears): |
| 45 | + x = self.linears[i // 2](x) + l(x) |
| 46 | + return x |
| 47 | +
|
| 48 | +""" |
| 49 | + |
| 50 | +from collections import OrderedDict, abc as container_abcs |
4 | 51 | import operator |
5 | | -from itertools import islice |
6 | | -from typing import TypeVar, Dict, Union, Iterable, Iterator, Any, overload |
| 52 | +from itertools import islice, chain |
| 53 | +from typing import TypeVar, Dict, Union, Iterable, Iterator, Any, overload, Optional |
7 | 54 |
|
8 | 55 | from adalflow.core.component import Component |
9 | 56 |
|
10 | 57 | T = TypeVar("T", bound=Component) |
11 | 58 |
|
| 59 | +__all__ = ["Sequential", "ComponentList"] |
| 60 | + |
12 | 61 |
|
13 | 62 | class Sequential(Component): |
14 | 63 | __doc__ = r"""A sequential container. |
@@ -311,3 +360,177 @@ def extend(self, components: Iterable[Component]) -> "Sequential": |
311 | 360 | for component in components: |
312 | 361 | self.append(component) |
313 | 362 | return self |
| 363 | + |
| 364 | + |
| 365 | +def _addindent(s_: str, numSpaces: int): |
| 366 | + s = s_.split("\n") |
| 367 | + # don't do anything for single-line stuff |
| 368 | + if len(s) == 1: |
| 369 | + return s_ |
| 370 | + first = s.pop(0) |
| 371 | + s = [(numSpaces * " ") + line for line in s] |
| 372 | + s = "\n".join(s) |
| 373 | + s = first + "\n" + s |
| 374 | + return s |
| 375 | + |
| 376 | + |
| 377 | +class ComponentList(Component): |
| 378 | + __doc__ = r"""Holds subcomponents in a list. |
| 379 | +
|
| 380 | + :class:`adalflow.core.ComponentList` can be indexed like a regular Python list, but |
| 381 | + the components it holds are properly registered, and will be visible by all |
| 382 | + :class:`adalflow.core.Component` methods. |
| 383 | +
|
| 384 | + Args: |
| 385 | + components (iterable, optional): an iterable of components to add |
| 386 | +
|
| 387 | + Examples: |
| 388 | +
|
| 389 | + .. code-block:: python |
| 390 | +
|
| 391 | + # Example of how to use ComponentList |
| 392 | + class MyComponents(Component): |
| 393 | + def __init__(self): |
| 394 | + super().__init__() |
| 395 | + self.llms = ComponentList([adal.Generator() for i in range(10)]) |
| 396 | +
|
| 397 | + def forward(self, x): |
| 398 | + for layer in self.layers: |
| 399 | + x = layer(x) |
| 400 | + return x |
| 401 | + """ |
| 402 | + _components: Dict[str, Component] = OrderedDict() |
| 403 | + |
| 404 | + def __init__(self, components: Optional[Iterable[Component]] = None) -> None: |
| 405 | + super().__init__() |
| 406 | + if components is not None: |
| 407 | + self += components |
| 408 | + |
| 409 | + def _get_abs_string_index(self, idx): |
| 410 | + """Get the absolute index as a string.""" |
| 411 | + idx = operator.index(idx) |
| 412 | + if not (-len(self) <= idx < len(self)): |
| 413 | + raise IndexError(f"index {idx} is out of range") |
| 414 | + if idx < 0: |
| 415 | + idx += len(self) |
| 416 | + return str(idx) |
| 417 | + |
| 418 | + def __getitem__(self, idx: Union[int, slice]) -> Union[Component, "ComponentList"]: |
| 419 | + """Retrieve a component or a slice of components.""" |
| 420 | + if isinstance(idx, slice): |
| 421 | + return self.__class__(list(self._components.values())[idx]) |
| 422 | + else: |
| 423 | + return self._components[self._get_abs_string_index(idx)] |
| 424 | + |
| 425 | + def __setitem__(self, idx: int, component: Component) -> None: |
| 426 | + """Set a component at the given index.""" |
| 427 | + idx = self._get_abs_string_index(idx) |
| 428 | + return setattr(self, str(idx), component) |
| 429 | + |
| 430 | + def __delitem__(self, idx: Union[int, slice]) -> None: |
| 431 | + """Delete a component or a slice of components.""" |
| 432 | + if isinstance(idx, slice): |
| 433 | + for k in range(len(self._components))[idx]: |
| 434 | + delattr(self, str(k)) |
| 435 | + else: |
| 436 | + delattr(self, self._get_abs_string_index(idx)) |
| 437 | + # To preserve numbering, self._components is being reconstructed with modules after deletion |
| 438 | + str_indices = [str(i) for i in range(len(self._components))] |
| 439 | + self._components = OrderedDict( |
| 440 | + list(zip(str_indices, self._components.values())) |
| 441 | + ) |
| 442 | + |
| 443 | + def __len__(self) -> int: |
| 444 | + """Return the number of components.""" |
| 445 | + return len(self._components) |
| 446 | + |
| 447 | + def __iter__(self) -> Iterator[Component]: |
| 448 | + """Iterate over the components.""" |
| 449 | + return iter(self._components.values()) |
| 450 | + |
| 451 | + def __iadd__(self, components: Iterable[Component]) -> "ComponentList": |
| 452 | + """Add multiple components using the `+=` operator.""" |
| 453 | + |
| 454 | + return self.extend(components) |
| 455 | + |
| 456 | + def __add__(self, other: Iterable[Component]) -> "ComponentList": |
| 457 | + """Concatenate two ComponentLists.""" |
| 458 | + |
| 459 | + combined = ComponentList() |
| 460 | + for i, component in enumerate(chain(self, other)): |
| 461 | + combined.add_component(str(i), component) |
| 462 | + return combined |
| 463 | + |
| 464 | + def __repr__(self): |
| 465 | + """Return a custom repr for ModuleList that compresses repeated module representations.""" |
| 466 | + list_of_reprs = [repr(item) for item in self] |
| 467 | + if len(list_of_reprs) == 0: |
| 468 | + return self._get_name() + "()" |
| 469 | + |
| 470 | + start_end_indices = [[0, 0]] |
| 471 | + repeated_blocks = [list_of_reprs[0]] |
| 472 | + for i, r in enumerate(list_of_reprs[1:], 1): |
| 473 | + if r == repeated_blocks[-1]: |
| 474 | + start_end_indices[-1][1] += 1 |
| 475 | + continue |
| 476 | + |
| 477 | + start_end_indices.append([i, i]) |
| 478 | + repeated_blocks.append(r) |
| 479 | + |
| 480 | + lines = [] |
| 481 | + main_str = self._get_name() + "(" |
| 482 | + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): |
| 483 | + local_repr = f"({start_id}): {b}" # default repr |
| 484 | + |
| 485 | + if start_id != end_id: |
| 486 | + n = end_id - start_id + 1 |
| 487 | + local_repr = f"({start_id}-{end_id}): {n} x {b}" |
| 488 | + |
| 489 | + local_repr = _addindent(local_repr, 2) |
| 490 | + lines.append(local_repr) |
| 491 | + |
| 492 | + main_str += "\n " + "\n ".join(lines) + "\n" |
| 493 | + main_str += ")" |
| 494 | + return main_str |
| 495 | + |
| 496 | + def __dir__(self): |
| 497 | + keys = super().__dir__() |
| 498 | + keys = [key for key in keys if not key.isdigit()] |
| 499 | + return keys |
| 500 | + |
| 501 | + def insert(self, index: int, component: Component) -> None: |
| 502 | + """Insert a component at the specified index.""" |
| 503 | + for i in range(len(self._components), index, -1): |
| 504 | + self._components[str(i)] = self._components[str(i - 1)] |
| 505 | + self._components[str(index)] = component |
| 506 | + |
| 507 | + def pop(self, index: Union[int, slice]) -> Component: |
| 508 | + """Remove and return a component at the given index.""" |
| 509 | + component = self[index] |
| 510 | + del self[index] |
| 511 | + return component |
| 512 | + |
| 513 | + def append(self, component: Component) -> "ComponentList": |
| 514 | + """Append a component to the list.""" |
| 515 | + # self._components[str(len(self))] = component |
| 516 | + self.add_component(str(len(self)), component) |
| 517 | + return self |
| 518 | + |
| 519 | + def extend(self, components: Iterable[Component]) -> "ComponentList": |
| 520 | + """Extend the list by appending multiple components.""" |
| 521 | + # for component in components: |
| 522 | + # self.append(component) |
| 523 | + # return self |
| 524 | + |
| 525 | + if not isinstance(components, container_abcs.Iterable): |
| 526 | + raise TypeError( |
| 527 | + "ModuleList.extend should be called with an " |
| 528 | + "iterable, but got " + type(components).__name__ |
| 529 | + ) |
| 530 | + offset = len(self) |
| 531 | + for i, component in enumerate(components): |
| 532 | + self.add_component(str(offset + i), component) |
| 533 | + return self |
| 534 | + |
| 535 | + |
| 536 | +# TODO: need to do the same to ParameterList and ParameterDict, ModuleDict |
0 commit comments