@@ -16,11 +16,40 @@ class NetworkProto(typing.Protocol):
16
16
def __call__ (self , x : torch .Tensor ) -> torch .Tensor :
17
17
"""
18
18
Calculate the amplitude for the given configurations.
19
+
20
+ Parameters
21
+ ----------
22
+ x : torch.Tensor
23
+ The configurations to calculate the amplitude for.
24
+ The configurations are a two-dimensional uint8 tensor with first dimension equal to some batch size.
25
+ The second dimension contains occupation for each qubit which is bitwise encoded.
26
+
27
+ Returns
28
+ -------
29
+ torch.Tensor
30
+ The amplitudes of the configurations.
31
+ The amplitudes are a one-dimensional complex tensor with the only dimension equal to the batch size.
19
32
"""
20
33
21
34
def generate_unique (self , batch_size : int , block_num : int = 1 ) -> tuple [torch .Tensor , torch .Tensor , None , None ]:
22
35
"""
23
36
Generate a batch of unique configurations.
37
+
38
+ Parameters
39
+ ----------
40
+ batch_size : int
41
+ The number of configurations to generate.
42
+ block_num : int, default=1
43
+ The number of batch block to generate. It is used to split the batch into smaller parts to avoid memory issues.
44
+
45
+ Returns
46
+ -------
47
+ tuple[torch.Tensor, torch.Tensor, None, None]
48
+ A tuple containing the generated configurations, their amplitudes, and two None values.
49
+ The configurations are a two-dimensional uint8 tensor with first dimension equal to `batch_size`.
50
+ The second dimension contains occupation for each qubit which is bitwise encoded.
51
+ The amplitudes are a one-dimensional complex tensor with the only dimension equal to `batch_size`.
52
+ The last two None values are reserved for future use.
24
53
"""
25
54
26
55
def load_state_dict (self , data : dict [str , torch .Tensor ]) -> typing .Any :
0 commit comments