16
16
import numpy as np
17
17
import tabulate
18
18
import torch
19
- import math
20
19
from fbgemm_gpu .split_embedding_configs import SparseType
21
20
from fbgemm_gpu .split_table_batched_embeddings_ops_common import (
22
21
BoundsCheckMode ,
@@ -101,6 +100,7 @@ def generate_requests(
101
100
102
101
# pyre-fixme[3]: Return type must be annotated.
103
102
103
+
104
104
def _get_random_tensor (
105
105
num_ads : int ,
106
106
embedding_dimension : int ,
@@ -109,7 +109,7 @@ def _get_random_tensor(
109
109
gpu_idx : int ,
110
110
include_quantization : bool ,
111
111
use_pitched : bool = True ,
112
- alignment : int = 256 , # alignment in bytes
112
+ alignment : int = 256 , # alignment in bytes
113
113
):
114
114
device = torch .device (f"cuda:{ gpu_idx } " )
115
115
@@ -120,12 +120,14 @@ def _get_random_tensor(
120
120
121
121
if use_pitched :
122
122
width_bytes = width_elems * elem_size
123
- pitch_bytes = math .ceil (width_bytes / alignment ) * alignment
123
+ pitch_bytes = int ( np .ceil (width_bytes / alignment ) * alignment )
124
124
pitch_elems = pitch_bytes // elem_size
125
125
storage = torch .empty ((num_ads , pitch_elems ), dtype = dtype , device = device )
126
126
result_tensor = storage [:, :width_elems ] # logical view
127
127
else :
128
- result_tensor = torch .randn (num_ads , width_elems , dtype = dtype , device = device )
128
+ result_tensor = torch .randn (
129
+ num_ads , width_elems , dtype = dtype , device = device
130
+ )
129
131
130
132
elif data_type == "INT8" :
131
133
assert embedding_dimension % 2 == 0 , "needs to align to 2 bytes for INT8"
@@ -135,12 +137,16 @@ def _get_random_tensor(
135
137
136
138
if use_pitched :
137
139
width_bytes = width_elems * elem_size
138
- pitch_bytes = math .ceil (width_bytes / alignment ) * alignment
140
+ pitch_bytes = int ( np .ceil (width_bytes / alignment ) * alignment )
139
141
pitch_elems = pitch_bytes // elem_size
140
- storage = torch .randint (0 , 255 , (num_ads , pitch_elems ), dtype = dtype , device = device )
142
+ storage = torch .randint (
143
+ 0 , 255 , (num_ads , pitch_elems ), dtype = dtype , device = device
144
+ )
141
145
result_tensor = storage [:, :width_elems ]
142
146
else :
143
- result_tensor = torch .randint (0 , 255 , (num_ads , width_elems ), dtype = dtype , device = device )
147
+ result_tensor = torch .randint (
148
+ 0 , 255 , (num_ads , width_elems ), dtype = dtype , device = device
149
+ )
144
150
145
151
elif data_type == "INT4" :
146
152
assert embedding_dimension % 4 == 0 , "needs to align to 2 bytes for INT4"
@@ -150,12 +156,16 @@ def _get_random_tensor(
150
156
151
157
if use_pitched :
152
158
width_bytes = width_elems * elem_size
153
- pitch_bytes = math .ceil (width_bytes / alignment ) * alignment
159
+ pitch_bytes = int ( np .ceil (width_bytes / alignment ) * alignment )
154
160
pitch_elems = pitch_bytes // elem_size
155
- storage = torch .randint (0 , 255 , (num_ads , pitch_elems ), dtype = dtype , device = device )
161
+ storage = torch .randint (
162
+ 0 , 255 , (num_ads , pitch_elems ), dtype = dtype , device = device
163
+ )
156
164
result_tensor = storage [:, :width_elems ]
157
165
else :
158
- result_tensor = torch .randint (0 , 255 , (num_ads , width_elems ), dtype = dtype , device = device )
166
+ result_tensor = torch .randint (
167
+ 0 , 255 , (num_ads , width_elems ), dtype = dtype , device = device
168
+ )
159
169
160
170
else :
161
171
raise ValueError
0 commit comments