Skip to content

Commit 0a02d7a

Browse files
committed
Use itertools.pairwise when constructing MLP network.
As mentioned by USTC-KnowledgeComputingLab/qmb#14 (comment) we could use `itertools.pairwise` to improve the readability of codes during constructing MLP network. PR Tracking at: USTC-KnowledgeComputingLab/qmb#23
2 parents 50bcf1b + 84d571f commit 0a02d7a

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

qmb/crossmlp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
This file implements a cross MLP network.
33
"""
44

5+
import itertools
56
import typing
67
import torch
78
from .bitspack import unpack_int
@@ -48,7 +49,7 @@ def __init__(self, dim_input: int, dim_output: int, hidden_size: tuple[int, ...]
4849
self.depth: int = len(hidden_size)
4950

5051
dimensions: list[int] = [dim_input] + list(hidden_size) + [dim_output]
51-
linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in zip(dimensions[:-1], dimensions[1:])]
52+
linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in itertools.pairwise(dimensions)]
5253
modules: list[torch.nn.Module] = [module for linear in linears for module in (linear, torch.nn.SiLU())][:-1]
5354
self.model: torch.nn.Module = torch.nn.Sequential(*modules)
5455

qmb/mlp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
This file implements the MLP network from https://arxiv.org/pdf/2109.12606 with the sampling method introduced in https://arxiv.org/pdf/2408.07625.
33
"""
44

5+
import itertools
56
import torch
67
from .bitspack import pack_int, unpack_int
78

@@ -47,7 +48,7 @@ def __init__(self, dim_input: int, dim_output: int, hidden_size: tuple[int, ...]
4748
self.depth: int = len(hidden_size)
4849

4950
dimensions: list[int] = [dim_input] + list(hidden_size) + [dim_output]
50-
linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in zip(dimensions[:-1], dimensions[1:])]
51+
linears: list[torch.nn.Module] = [select_linear_layer(i, j) for i, j in itertools.pairwise(dimensions)]
5152
modules: list[torch.nn.Module] = [module for linear in linears for module in (linear, torch.nn.SiLU())][:-1]
5253
self.model: torch.nn.Module = torch.nn.Sequential(*modules)
5354

0 commit comments

Comments
 (0)