|
17 | 17 | import numpy as np
|
18 | 18 | import tabulate
|
19 | 19 | import torch
|
20 |
| - |
| 20 | +import math |
21 | 21 | from fbgemm_gpu.split_embedding_configs import SparseType
|
22 | 22 | from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
23 | 23 | BoundsCheckMode,
|
@@ -101,45 +101,63 @@ def generate_requests(
|
101 | 101 |
|
102 | 102 |
|
# pyre-fixme[3]: Return type must be annotated.
def _get_random_tensor(
    num_ads: int,
    embedding_dimension: int,
    ads_tables: int,
    data_type: str,
    gpu_idx: int,
    include_quantization: bool,
    use_pitched: bool = True,
    alignment: int = 256,  # alignment in bytes
):
    """Build a random ``num_ads x row_width`` tensor on ``cuda:<gpu_idx>``.

    The logical row width depends on ``data_type``:
      * FP16 (or whenever ``include_quantization`` is set):
        ``embedding_dimension * ads_tables`` float16 values.
      * INT8: ``(embedding_dimension + 4) * ads_tables`` uint8 values —
        2 FP16 numbers for scale and bias add 4 bytes of overhead per table.
      * INT4: ``(embedding_dimension // 2 + 4) * ads_tables`` uint8 values —
        two int4 values packed per byte, plus the same 4-byte overhead.

    When ``use_pitched`` is True, each row's backing storage is padded
    ("pitched") up to a multiple of ``alignment`` bytes and the returned
    ``result_tensor`` is a non-contiguous view restricted to the logical
    columns; otherwise the tensor is densely allocated.

    Raises:
        ValueError: for an unsupported ``data_type``.
    """
    device = torch.device(f"cuda:{gpu_idx}")

    def _pitched_row_elems(width_elems: int, elem_size: int) -> int:
        # Round the row width up to the next multiple of `alignment` bytes,
        # then express the pitch back in element units.
        pitch_bytes = math.ceil(width_elems * elem_size / alignment) * alignment
        return pitch_bytes // elem_size

    if data_type == "FP16" or include_quantization:
        dtype = torch.float16
        width_elems = embedding_dimension * ads_tables
        elem_size = torch.finfo(dtype).bits // 8
        if use_pitched:
            pitch_elems = _pitched_row_elems(width_elems, elem_size)
            # Fix: fill the pitched storage with randn (matching the
            # non-pitched path below) instead of torch.empty, which returned
            # a view over uninitialized memory.
            storage = torch.randn(num_ads, pitch_elems, dtype=dtype, device=device)
            result_tensor = storage[:, :width_elems]  # logical view
        else:
            result_tensor = torch.randn(num_ads, width_elems, dtype=dtype, device=device)
    elif data_type in ("INT8", "INT4"):
        if data_type == "INT8":
            assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8"
            # 2 FP16 numbers for scale and bias: 4 bytes of overhead per table.
            width_elems = (embedding_dimension + 4) * ads_tables
        else:
            assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4"
            # Two int4 values per uint8 byte, plus 4 bytes scale/bias overhead.
            width_elems = (embedding_dimension // 2 + 4) * ads_tables
        dtype = torch.uint8
        if use_pitched:
            pitch_elems = _pitched_row_elems(width_elems, 1)
            storage = torch.randint(0, 255, (num_ads, pitch_elems), dtype=dtype, device=device)
            result_tensor = storage[:, :width_elems]
        else:
            result_tensor = torch.randint(0, 255, (num_ads, width_elems), dtype=dtype, device=device)
    else:
        raise ValueError(f"Unsupported data_type: {data_type!r}")
|
145 | 163 |
|
|
0 commit comments