Skip to content

Commit ac2e41b

Browse files
committed
Add documents for functions in network proto.
PR: USTC-KnowledgeComputingLab/qmb#33 Signed-off-by: Hao Zhang <[email protected]>
2 parents 6f670c5 + 4981b0d commit ac2e41b

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

qmb/model_dict.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,40 @@ class NetworkProto(typing.Protocol):
1616
def __call__(self, x: torch.Tensor) -> torch.Tensor:
1717
"""
1818
Calculate the amplitude for the given configurations.
19+
20+
Parameters
21+
----------
22+
x : torch.Tensor
23+
The configurations to calculate the amplitude for.
24+
The configurations are a two-dimensional uint8 tensor with first dimension equal to some batch size.
25+
The second dimension contains occupation for each qubit which is bitwise encoded.
26+
27+
Returns
28+
-------
29+
torch.Tensor
30+
The amplitudes of the configurations.
31+
The amplitudes are a one-dimensional complex tensor with the only dimension equal to the batch size.
1932
"""
2033

2134
def generate_unique(self, batch_size: int, block_num: int = 1) -> tuple[torch.Tensor, torch.Tensor, None, None]:
2235
"""
2336
Generate a batch of unique configurations.
37+
38+
Parameters
39+
----------
40+
batch_size : int
41+
The number of configurations to generate.
42+
block_num : int, default=1
43+
The number of batch block to generate. It is used to split the batch into smaller parts to avoid memory issues.
44+
45+
Returns
46+
-------
47+
tuple[torch.Tensor, torch.Tensor, None, None]
48+
A tuple containing the generated configurations, their amplitudes, and two None values.
49+
The configurations are a two-dimensional uint8 tensor with first dimension equal to `batch_size`.
50+
The second dimension contains occupation for each qubit which is bitwise encoded.
51+
The amplitudes are a one-dimensional complex tensor with the only dimension equal to `batch_size`.
52+
The last two None values are reserved for future use.
2453
"""
2554

2655
def load_state_dict(self, data: dict[str, torch.Tensor]) -> typing.Any:

0 commit comments

Comments
 (0)