diff --git a/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py b/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py index 4561f0a70..8088bf49a 100644 --- a/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py +++ b/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py @@ -322,7 +322,6 @@ def create_dynamic_embedding_tables(args, device): score_strategy=DynamicEmbScoreStrategy.LFU if args.cache_algorithm == "lfu" else DynamicEmbScoreStrategy.TIMESTAMP, - caching=args.caching, ) ) diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py index 213c24fbc..2f45dbcbc 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py @@ -528,7 +528,7 @@ def __init__( def _create_cache_storage(self) -> None: self._storages: List[Storage] = [] self._caches: List[Cache] = [] - self._caching = self._dynamicemb_options[0].caching + self._caching = False for option in self._dynamicemb_options: if option.training and option.optimizer_type == OptimizerType.Null: @@ -539,23 +539,36 @@ def _create_cache_storage(self) -> None: "Set OptimizerType to Null as not on training mode.", UserWarning ) - if option.caching and option.training: - cache_option = deepcopy(option) - cache_option.bucket_capacity = 1024 - capacity = get_constraint_capacity( - option.local_hbm_for_values, - option.embedding_dtype, - option.dim, - option.optimizer_type, - cache_option.bucket_capacity, - ) - if capacity == 0: - raise ValueError( - "Can't use caching mode as the reserved HBM size is too small." - ) + value_size = get_value_size( + option.embedding_dtype, + option.dim, + option.optimizer_type, + ) + total_table_bytes = value_size * option.max_capacity + hbm_budget = option.local_hbm_for_values + + if hbm_budget == 0: + # No HBM budget -> storage only, on host + option.local_hbm_for_values = 0 + self._caches.append(None) + self._storages.append(DynamicEmbeddingTable(option, self._optimizer)) + elif total_table_bytes <= hbm_budget: + # Entire table fits in HBM -> single table on GPU serves as + # both cache and storage (no eviction needed). 
+ self._caching = True + table = DynamicEmbeddingTable(option, self._optimizer) + self._caches.append(table) + self._storages.append(table) + else: + # Partial HBM -> cache (GPU) + storage (host or external) + self._caching = True + bucket_capacity_for_cache = 1024 + cache_capacity = hbm_budget // value_size - cache_option.max_capacity = capacity - cache_option.init_capacity = capacity + cache_option = deepcopy(option) + cache_option.bucket_capacity = bucket_capacity_for_cache + cache_option.max_capacity = cache_capacity + cache_option.init_capacity = cache_capacity self._caches.append( DynamicEmbeddingTable(cache_option, self._optimizer) ) @@ -568,9 +581,6 @@ def _create_cache_storage(self) -> None: if PS else DynamicEmbeddingTable(storage_option, self._optimizer) ) - else: - self._caches.append(None) - self._storages.append(DynamicEmbeddingTable(option, self._optimizer)) _print_memory_consume( self._table_names, self._dynamicemb_options, self._optimizer, self.device_id diff --git a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py index 5f12c1b47..c99fc56b3 100644 --- a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py +++ b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py @@ -222,22 +222,6 @@ class DynamicEmbTableOptions(_ContextOptions): For `UNIFORM` and `TRUNCATED_NORMAL`, the `lower` and `upper` will set to $\pm {1 \over \sqrt{EmbeddingConfig.num\_embeddings}}$. eval_initializer_args: DynamicEmbInitializerArgs The initializer args for evaluation mode, and will return torch.zeros(...) as embedding by default if index/sparse feature is missing. - caching: bool - Flag to indicate dynamic embedding tables is working on caching mode, default to `False`. - When the device memory on a single GPU is insufficient to accommodate a single shard of the dynamic embedding table, - dynamicemb supports the mixed use of device memory and host memory(pinned memory). - But by default, the values of the entire table are concatenated with device memory and host memory. - This means that the storage location of one embedding is determined by `hash_function(key)`, and mapping to device memory will bring better lookup performance. - However, sparse features in training are often with temporal locality. - In order to store hot keys in device memory, dynamicemb creates two table instances, - whose values are stored in device memory and host memory respectively, and store hot keys on the GPU table priorily. - If the GPU table is full, the evicted keys will be inserted into the host table. - If the host table is also full, the key will be evicted(all the eviction is based on the score per key). - The original intention of eviction is based on this insight: features that only appear once should not occupy memory(even host memory) for a long time. - In short: - set **`caching=True`** will create a GPU table and a host table, and make GPU table serves as a cache; - set **`caching=False`** will create a hybrid table which use GPU and host memory in a concatenated way to store value. - All keys and other meta data are always stored on GPU for both cases. init_capacity : Optional[int], optional The initial capacity of the table. If not set, it defaults to max_capacity after sharding. If `init_capacity` is provided, it will serve as the initial table capacity on a single GPU. @@ -265,9 +249,9 @@ class DynamicEmbTableOptions(_ContextOptions): Please refer to the API documentation for DynamicEmbCheckMode for more information. 
global_hbm_for_values : int Total GPU memory allocated to store embedding + optimizer states, in bytes. Default is 0. - It has different meanings under `caching=True` and `caching=False`. - When `caching=False`, it decides how much GPU memory is in the total memory to store value in a single hybrid table. - When `caching=True`, it decides the table capacity of the GPU table. + If the budget can hold the entire table (max_capacity * value_size), the table lives entirely on GPU. + If the budget is nonzero but smaller, it determines the GPU cache capacity while the full table is stored on host/external storage. + If zero, the table is stored entirely on host memory. external_storage: Storage The external storage/ParamterServer which inherits the interface of Storage, and can be configured per table. If not provided, will using DynamicEmbeddingTable as the Storage. @@ -297,7 +281,6 @@ class DynamicEmbTableOptions(_ContextOptions): value=0.0, ) ) - caching: bool = False init_capacity: Optional[ int ] = None # if not set then set to max_capcacity after sharded @@ -339,7 +322,6 @@ def __ne__(self, other): def get_grouped_key(self): grouped_key = {} grouped_key["training"] = self.training - grouped_key["caching"] = self.caching grouped_key["external_storage"] = self.external_storage grouped_key["index_type"] = self.index_type grouped_key["score_strategy"] = self.score_strategy @@ -498,6 +480,30 @@ def validate_initializer_args( initializer_args.upper = default_upper +def get_consumed_bytes_of_table( + max_capacity, + dtype, + dim, + optimizer_type, +) -> int: + byte_consume_per_vector = ( + dim + get_optimizer_state_dim(optimizer_type, dim, dtype) + ) * dtype_to_bytes(dtype) + total_consumed = max_capacity * byte_consume_per_vector + return total_consumed + + +def get_value_size( + dtype, + dim, + optimizer_type, +) -> int: + byte_consume_per_vector = ( + dim + get_optimizer_state_dim(optimizer_type, dim, dtype) + ) * dtype_to_bytes(dtype) + return byte_consume_per_vector + + + def get_constraint_capacity( memory_bytes, dtype, diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py index 8f0fc4984..8bc6a112c 100644 --- a/corelib/dynamicemb/dynamicemb/key_value_table.py +++ b/corelib/dynamicemb/dynamicemb/key_value_table.py @@ -50,11 +50,11 @@ from dynamicemb_extensions import ( EvictStrategy, device_timestamp, - load_from_combined_table, + load_from_table, select, select_index, select_insert_failed_values, - store_to_combined_table, + store_to_table, ) from torch import Tensor, nn # usort:skip from torchrec import JaggedTensor @@ -118,10 +118,6 @@ def __init__( device_idx = torch.cuda.current_device() self.device = torch.device(f"cuda:{device_idx}") - # assert ( - # options.init_capacity == options.max_capacity - # ), "Capacity growth is appending..."
- self.score_policy = get_score_policy(options.score_strategy) self.evict_strategy_ = options.evict_strategy.value @@ -134,50 +130,24 @@ def __init__( ) self._capacity = self.key_index_map.capacity() - # TODO: maybe we can separate it in the future like fbgemm - # self.embedding_weights_dev = None - # self.embedding_weights_uvm = None - # self.optimizer_states_dev = None - # self.optimizer_states_uvm = None - self.dev_table = None - self.uvm_table = None - self._emb_dtype = self.options.embedding_dtype self._emb_dim = self.options.dim self._optim_states_dim = optimizer.get_state_dim(self._emb_dim) self._value_dim = self._emb_dim + self._optim_states_dim - total_memory_need = ( - self._capacity * self._value_dim * dtype_to_bytes(self._emb_dtype) - ) - if options.local_hbm_for_values == 0: - # weight_uvm_size = self._capacity * self._emb_dim - # optim_states_uvm_size = self._capacity * self._optim_states_dim - # self.embedding_weights_uvm = get_uvm_tensor(weight_uvm_size, self._emb_dtype, self.device) - # self.embedding_weights_uvm = get_uvm_tensor(optim_states_uvm_size, self._emb_dtype, self.device) - uvm_size = self._capacity * self._value_dim - self.uvm_table = get_uvm_tensor( - uvm_size, dtype=self._emb_dtype, device=self.device - ).view(-1, self._value_dim) - elif options.local_hbm_for_values >= total_memory_need: - dev_size = self._capacity * self._value_dim - self.dev_table = torch.empty( - dev_size, dtype=self._emb_dtype, device=self.device + self._on_device = options.local_hbm_for_values != 0 + num_elements = self._capacity * self._value_dim + if self._on_device: + self._table = torch.empty( + num_elements, dtype=self._emb_dtype, device=self.device ).view(-1, self._value_dim) else: - # hybrid mode - dev_size = ( - options.local_hbm_for_values - // (self._value_dim * dtype_to_bytes(self._emb_dtype)) - * self._value_dim - ) - uvm_size = self._capacity * self._value_dim - dev_size - self.dev_table = torch.empty( - dev_size, dtype=self._emb_dtype, device=self.device - ).view(-1, self._value_dim) - self.uvm_table = get_uvm_tensor( - uvm_size, dtype=self._emb_dtype, device=self.device + self._table = get_uvm_tensor( + num_elements, dtype=self._emb_dtype, device=self.device ).view(-1, self._value_dim) + # TODO: maybe we can separate it in the future like fbgemm + # self.embedding_weights = None + # self.optimizer_states = None self.score: int = None self._score_update = False @@ -216,32 +186,15 @@ def expand(self): device=self.device, ) - dev_table = None - uvm_table = None - - if self.options.local_hbm_for_values == 0: - uvm_size = target_capacity * self._value_dim - uvm_table = get_uvm_tensor( - uvm_size, dtype=self._emb_dtype, device=self.device - ).view(-1, self._value_dim) - elif self.options.local_hbm_for_values >= total_memory_need: - dev_size = target_capacity * self._value_dim - dev_table = torch.empty( - dev_size, dtype=self._emb_dtype, device=self.device + table = None + num_elements = target_capacity * self._value_dim + if self._on_device: + table = torch.empty( + num_elements, dtype=self._emb_dtype, device=self.device ).view(-1, self._value_dim) else: - # hybrid mode - dev_size = ( - self.options.local_hbm_for_values - // (self._value_dim * dtype_to_bytes(self._emb_dtype)) - * self._value_dim - ) - uvm_size = target_capacity * self._value_dim - dev_size - dev_table = torch.empty( - dev_size, dtype=self._emb_dtype, device=self.device - ).view(-1, self._value_dim) - uvm_table = get_uvm_tensor( - uvm_size, dtype=self._emb_dtype, device=self.device + table = get_uvm_tensor( + 
num_elements, dtype=self._emb_dtype, device=self.device ).view(-1, self._value_dim) for ( @@ -260,7 +213,7 @@ def expand(self): dtype=self.value_type(), device=self.device, ) - load_from_combined_table(self.dev_table, self.uvm_table, indices, values) + load_from_table(self._table, indices, values) # when load into the table, we always assign the scores from the file. score_args_insert = [ @@ -276,16 +229,14 @@ def expand(self): ) key_index_map.insert(keys, score_args_insert, indices) - store_to_combined_table( - dev_table, - uvm_table, + store_to_table( + table, indices, values, ) self.key_index_map = key_index_map - self.dev_table = dev_table - self.uvm_table = uvm_table + self._table = table self._capacity = target_capacity def find_impl( @@ -342,9 +293,7 @@ def find_impl( self.key_index_map.lookup(unique_keys, score_args_lookup, founds, indices) if load_dim != 0: - load_from_combined_table( - self.dev_table, self.uvm_table, indices, unique_embs - ) + load_from_table(self._table, indices, unique_embs) missing = torch.logical_not(founds) num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device) @@ -484,9 +433,7 @@ def insert( ) self.key_index_map.insert(unique_keys, score_args_insert, indices) - store_to_combined_table( - self.dev_table, self.uvm_table, indices, unique_values.to(self.value_type()) - ) + store_to_table(self._table, indices, unique_values.to(self.value_type())) def update( self, @@ -517,7 +464,9 @@ def update( self.key_index_map.lookup(keys, score_args_lookup, founds, indices) self.optimizer.fused_update_with_index( - grads.to(self.value_type()), indices, self.dev_table, self.uvm_table + grads.to(self.value_type()), + indices, + self._table, ) if return_missing: @@ -841,9 +790,8 @@ def load_key_values( ) self.key_index_map.insert(keys, score_args_insert, indices) - store_to_combined_table( - self.dev_table, - self.uvm_table, + store_to_table( + self._table, indices, values.to(value_type), ) @@ -874,7 +822,7 @@ def export_keys_values( dtype=self.value_type(), device=self.device, ) - load_from_combined_table(self.dev_table, self.uvm_table, indices, values) + load_from_table(self._table, indices, values) embeddings = ( values[:, : self._emb_dim].to(dtype=EMBEDDING_TYPE).contiguous() ) @@ -958,15 +906,13 @@ def insert_and_evict( select_insert_failed_values(evicted_indices, values, evicted_values) - load_from_combined_table( - self.dev_table, - self.uvm_table, + load_from_table( + self._table, evicted_indices, evicted_values, ) - store_to_combined_table( - self.dev_table, self.uvm_table, indices, values.to(self.value_type()) - ) + + store_to_table(self._table, indices, values.to(self.value_type())) if self._record_cache_metrics: self._cache_metrics[2] = batch @@ -999,7 +945,7 @@ def flush(self, storage: Storage) -> None: device=self.device, ) - load_from_combined_table(self.dev_table, self.uvm_table, indices, values) + load_from_table(self._table, indices, values) storage.insert(keys, values, scores) def reset( @@ -1055,9 +1001,7 @@ def incremental_dump( keys.numel(), self._emb_dim, dtype=self._emb_dtype, device=self.device ) if keys.numel() != 0: - load_from_combined_table( - self.dev_table, self.uvm_table, indices.to(self.device), embs - ) + load_from_table(self._table, indices.to(self.device), embs) if not dist.is_initialized() or dist.get_world_size(group=pg) == 1: return keys, embs.cpu() @@ -1218,9 +1162,7 @@ def deterministic_insert( h_num_unique_keys, dtype=torch.bool, device=unique_keys.device ) self.key_index_map.lookup(unique_keys, 
score_args_lookup, founds, indices) - store_to_combined_table( - self.dev_table, self.uvm_table, indices, unique_values.to(self.value_type()) - ) + store_to_table(self._table, indices, unique_values.to(self.value_type())) return def deterministic_insert_and_evict( @@ -1324,9 +1266,8 @@ def deterministic_insert_and_evict( assert len(set(evicted_indices_accum.tolist())) == num_evicted_accum - load_from_combined_table( - self.dev_table, - self.uvm_table, + load_from_table( + self._table, evicted_indices_accum, evicted_values_accum, ) @@ -1348,9 +1289,7 @@ def deterministic_insert_and_evict( h_num_unique_keys, dtype=torch.bool, device=unique_keys.device ) self.key_index_map.lookup(unique_keys, score_args_lookup, founds, indices) - store_to_combined_table( - self.dev_table, self.uvm_table, indices, unique_values.to(self.value_type()) - ) + store_to_table(self._table, indices, unique_values.to(self.value_type())) if self._record_cache_metrics: self._cache_metrics[2] = h_num_unique_keys diff --git a/corelib/dynamicemb/dynamicemb/optimizer.py b/corelib/dynamicemb/dynamicemb/optimizer.py index 292969cb9..4e7fa955a 100644 --- a/corelib/dynamicemb/dynamicemb/optimizer.py +++ b/corelib/dynamicemb/dynamicemb/optimizer.py @@ -22,10 +22,10 @@ import torch # usort:skip from dynamicemb_extensions import ( OptimizerType, - adagrad_update_for_combined_table, - adam_update_for_combined_table, - rowwise_adagrad_for_combined_table, - sgd_update_for_combined_table, + adagrad_update_for_table, + adam_update_for_table, + rowwise_adagrad_for_table, + sgd_update_for_table, ) @@ -124,8 +124,7 @@ def fused_update_with_index( self, grads: torch.Tensor, indices: torch.Tensor, - dev_table: torch.Tensor, - uvm_table: torch.Tensor, + table: torch.Tensor, ) -> None: ... @@ -192,15 +191,13 @@ def fused_update_with_index( self, grads: torch.Tensor, indices: torch.Tensor, - dev_table: torch.Tensor, - uvm_table: torch.Tensor, + table: torch.Tensor, ) -> None: lr = self._opt_args.learning_rate - sgd_update_for_combined_table( + sgd_update_for_table( grads, indices, - dev_table, - uvm_table, + table, lr, ) @@ -237,8 +234,7 @@ def fused_update_with_index( self, grads: torch.Tensor, indices: torch.Tensor, - dev_table: torch.Tensor, - uvm_table: torch.Tensor, + table: torch.Tensor, ) -> None: lr = self._opt_args.learning_rate beta1 = self._opt_args.beta1 @@ -249,11 +245,10 @@ def fused_update_with_index( emb_dim = grads.size(1) state_dim = self.get_state_dim(emb_dim) - adam_update_for_combined_table( + adam_update_for_table( grads, indices, - dev_table, - uvm_table, + table, state_dim, lr, beta1, @@ -302,8 +297,7 @@ def fused_update_with_index( self, grads: torch.Tensor, indices: torch.Tensor, - dev_table: torch.Tensor, - uvm_table: torch.Tensor, + table: torch.Tensor, ) -> None: lr = self._opt_args.learning_rate eps = self._opt_args.eps @@ -311,11 +305,10 @@ def fused_update_with_index( emb_dim = grads.size(1) state_dim = self.get_state_dim(emb_dim) - adagrad_update_for_combined_table( + adagrad_update_for_table( grads, indices, - dev_table, - uvm_table, + table, state_dim, lr, eps, @@ -363,8 +356,7 @@ def fused_update_with_index( self, grads: torch.Tensor, indices: torch.Tensor, - dev_table: torch.Tensor, - uvm_table: torch.Tensor, + table: torch.Tensor, ) -> None: lr = self._opt_args.learning_rate eps = self._opt_args.eps @@ -372,11 +364,10 @@ def fused_update_with_index( emb_dim = grads.size(1) state_dim = self.get_state_dim(emb_dim) - rowwise_adagrad_for_combined_table( + rowwise_adagrad_for_table( grads, indices, - dev_table, 
- uvm_table, + table, state_dim, lr, eps, diff --git a/corelib/dynamicemb/example/example.py b/corelib/dynamicemb/example/example.py index 8426afe8e..cbd72300e 100644 --- a/corelib/dynamicemb/example/example.py +++ b/corelib/dynamicemb/example/example.py @@ -521,7 +521,6 @@ def get_planner( mode=DynamicEmbInitializerMode.NORMAL ), score_strategy=DynamicEmbScoreStrategy.STEP, - caching=caching, training=training, admit_strategy=admit_strategy, admission_counter=admission_counter, @@ -596,7 +595,7 @@ def apply_dmp(model, args, training): args.batch_size, optimizer_type=optimizer_type, training=training, - caching=args.caching, + caching=args.caching, # used for HBM budget calculation, not passed to options args=args, ) # get plan for all ranks. diff --git a/corelib/dynamicemb/src/dynamic_emb_op.cu b/corelib/dynamicemb/src/dynamic_emb_op.cu index 31b974c55..6f391e02d 100644 --- a/corelib/dynamicemb/src/dynamic_emb_op.cu +++ b/corelib/dynamicemb/src/dynamic_emb_op.cu @@ -282,11 +282,9 @@ reduce_grads(at::Tensor reverse_indices, at::Tensor grads, int64_t num_unique, } template -__global__ void load_from_combined_table_kernel_vec4( - int64_t batch, int emb_dim, int stride, int64_t output_stride, - int split_index, ValueT const *__restrict__ dev_table, - ValueT const *__restrict__ uvm_table, ValueT *__restrict__ output_buffer, - IndexT const *__restrict__ indices) { +__global__ void load_from_table_kernel_vec4( + int64_t batch, int emb_dim, int stride, ValueT const *__restrict__ table, + ValueT *__restrict__ output_buffer, IndexT const *__restrict__ indices) { constexpr int kWarpSize = 32; constexpr int VecSize = 4; @@ -298,87 +296,61 @@ __global__ void load_from_combined_table_kernel_vec4( for (int64_t emb_id = warp_num_per_block * blockIdx.x + warp_id_in_block; emb_id < batch; emb_id += gridDim.x * warp_num_per_block) { IndexT const index = indices[emb_id]; - ValueT const *src = nullptr; - if (index < split_index) { - src = dev_table + index * stride; - } else { - src = uvm_table + (index - split_index) * stride; - } - ValueT *dst = output_buffer + emb_id * output_stride; - if (index >= 0) { - for (int i = 0; VecSize * (kWarpSize * i + lane_id) < emb_dim; ++i) { - int idx4 = VecSize * (kWarpSize * i + lane_id); - emb.load(src + idx4); - emb.store(dst + idx4); - } + + if (index < 0) + continue; + ValueT const *src = table + index * stride; + ValueT *dst = output_buffer + emb_id * emb_dim; + + for (int i = 0; VecSize * (kWarpSize * i + lane_id) < emb_dim; ++i) { + int idx4 = VecSize * (kWarpSize * i + lane_id); + emb.load(src + idx4); + emb.store(dst + idx4); } } } template -__global__ void load_from_combined_table_kernel( - int64_t batch, int emb_dim, int stride, int64_t output_stride, - int split_index, ValueT const *__restrict__ dev_table, - ValueT const *__restrict__ uvm_table, ValueT *__restrict__ output_buffer, - IndexT const *__restrict__ indices) { +__global__ void load_from_table_kernel(int64_t batch, int emb_dim, int stride, + ValueT const *__restrict__ table, + ValueT *__restrict__ output_buffer, + IndexT const *__restrict__ indices) { for (int64_t emb_id = blockIdx.x; emb_id < batch; emb_id += gridDim.x) { IndexT const index = indices[emb_id]; - ValueT const *src = nullptr; - if (index < split_index) { - src = dev_table + index * stride; - } else { - src = uvm_table + (index - split_index) * stride; - } - ValueT *dst = output_buffer + emb_id * output_stride; - if (index >= 0) { - for (int i = threadIdx.x; i < emb_dim; i += blockDim.x) { - dst[i] = src[i]; - } + if (index < 0) + 
continue; + ValueT const *src = table + index * stride; + ValueT *dst = output_buffer + emb_id * emb_dim; + + for (int i = threadIdx.x; i < emb_dim; i += blockDim.x) { + dst[i] = src[i]; } } } -void load_from_combined_table(std::optional dev_table, - std::optional uvm_table, - at::Tensor indices, at::Tensor output) { +void load_from_table(at::Tensor table, at::Tensor indices, at::Tensor output) { int64_t num_total = indices.size(0); - if (num_total == 0) { - return; - } - int64_t stride = -1; int64_t dim = output.size(1); - int64_t output_stride = output.stride(0); - if ((not dev_table.has_value()) and (not uvm_table.has_value())) { - throw std::runtime_error("Two tables cannot both be None."); - } else { - if (dev_table.has_value()) { - stride = dev_table.value().size(1); - if (stride < dim) { - throw std::runtime_error( - "Output tensor's dim1 should not be greater than the table's."); - } - } else { - stride = uvm_table.value().size(1); - if (stride < dim) { - throw std::runtime_error( - "Output tensor's dim1 should not be greater than the table's."); - } - } - } - if (output.dim() != 2) { - throw std::runtime_error("Output tensor should be 2-dim."); + if (output.dim() != 2 or table.dim() != 2) { + throw std::runtime_error("Output and table tensors should be 2-dim."); } if (output.size(0) != indices.size(0)) { throw std::runtime_error("Output tensor mismatches with indices at dim-0."); } - int64_t split_index = 0; - if (dev_table.has_value()) { - split_index = dev_table.value().size(0); + int64_t stride = table.size(1); + + if (dim > stride) { + throw std::runtime_error( + "Output tensor's dim1 should not be greater than the table's."); + } + + if (num_total == 0) { + return; } auto val_type = get_data_type(output); @@ -405,25 +377,22 @@ void load_from_combined_table(std::optional dev_table, DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, ValueType, [&] { DISPATCH_OFFSET_INT_TYPE(index_type, IndexType, [&] { - auto dev_ptr = get_pointer(dev_table); - auto uvm_ptr = get_pointer(uvm_table); + auto table_ptr = get_pointer(table); auto out_ptr = get_pointer(output); auto index_ptr = get_pointer(indices); if (dim % 4 == 0) { - load_from_combined_table_kernel_vec4 + load_from_table_kernel_vec4 <<>>( - num_total, dim, stride, output_stride, split_index, dev_ptr, - uvm_ptr, out_ptr, index_ptr); + num_total, dim, stride, table_ptr, out_ptr, index_ptr); } else { int block_size = dim < device_prop.max_thread_per_block ?
dim : device_prop.max_thread_per_block; int grid_size = num_total; - load_from_combined_table_kernel + load_from_table_kernel <<>>( - num_total, dim, stride, output_stride, split_index, dev_ptr, - uvm_ptr, out_ptr, index_ptr); + num_total, dim, stride, table_ptr, out_ptr, index_ptr); } }); }); @@ -431,10 +400,11 @@ void load_from_combined_table(std::optional dev_table, } template -__global__ void store_to_combined_table_kernel_vec4( - int64_t batch, int stride, int split_index, ValueT *__restrict__ dev_table, - ValueT *__restrict__ uvm_table, ValueT const *__restrict__ input_buffer, - IndexT const *__restrict__ indices) { +__global__ void +store_to_table_kernel_vec4(int64_t batch, int stride, + ValueT *__restrict__ table, + ValueT const *__restrict__ input_buffer, + IndexT const *__restrict__ indices) { constexpr int kWarpSize = 32; constexpr int VecSize = 4; @@ -446,88 +416,68 @@ __global__ void store_to_combined_table_kernel_vec4( for (int64_t emb_id = warp_num_per_block * blockIdx.x + warp_id_in_block; emb_id < batch; emb_id += gridDim.x * warp_num_per_block) { IndexT const index = indices[emb_id]; - ValueT *dst = nullptr; - if (index < split_index) { - dst = dev_table + index * stride; - } else { - dst = uvm_table + (index - split_index) * stride; - } + + if (index < 0) + continue; + ValueT *dst = table + index * stride; ValueT const *src = input_buffer + emb_id * stride; - if (index >= 0) { - for (int i = 0; VecSize * (kWarpSize * i + lane_id) < stride; ++i) { - int idx4 = VecSize * (kWarpSize * i + lane_id); - emb.load(src + idx4); - emb.store(dst + idx4); - } + + for (int i = 0; VecSize * (kWarpSize * i + lane_id) < stride; ++i) { + int idx4 = VecSize * (kWarpSize * i + lane_id); + emb.load(src + idx4); + emb.store(dst + idx4); } } } template -__global__ void store_to_combined_table_kernel( - int64_t batch, int stride, int split_index, ValueT *__restrict__ dev_table, - ValueT *__restrict__ uvm_table, ValueT const *__restrict__ input_buffer, - IndexT const *__restrict__ indices) { +__global__ void store_to_table_kernel(int64_t batch, int stride, + ValueT *__restrict__ table, + ValueT const *__restrict__ input_buffer, + IndexT const *__restrict__ indices) { for (int64_t emb_id = blockIdx.x; emb_id < batch; emb_id += gridDim.x) { IndexT const index = indices[emb_id]; - ValueT *dst = nullptr; - if (index < split_index) { - dst = dev_table + index * stride; - } else { - dst = uvm_table + (index - split_index) * stride; - } + if (index < 0) + continue; + ValueT *dst = table + index * stride; ValueT const *src = input_buffer + emb_id * stride; - if (index >= 0) { - for (int i = threadIdx.x; i < stride; i += blockDim.x) { - dst[i] = src[i]; - } + + for (int i = threadIdx.x; i < stride; i += blockDim.x) { + dst[i] = src[i]; } } } -void store_to_combined_table(std::optional dev_table, - std::optional uvm_table, - at::Tensor indices, at::Tensor input) { +void store_to_table(at::Tensor table, + + at::Tensor indices, at::Tensor input) { - int64_t stride = -1; + int64_t num_total = indices.size(0); int64_t dim = input.size(1); - if ((not dev_table.has_value()) and (not uvm_table.has_value())) { - throw std::runtime_error("Two tables cannot both be None."); - } else { - if (dev_table.has_value()) { - stride = dev_table.value().size(1); - if (stride != dim) { - throw std::runtime_error( - "Input tensor's dim1 should equal to the table's."); - } - } else { - stride = uvm_table.value().size(1); - if (stride != dim) { - throw std::runtime_error( - "Input tensor's dim1 should equal to the table's."); - 
} - } - } - if (input.dim() != 2) { - throw std::runtime_error("Input tensor should be 2-dim."); + if (input.dim() != 2 or table.dim() != 2) { + throw std::runtime_error("Input and table tensors should be 2-dim."); } if (input.size(0) != indices.size(0)) { - throw std::runtime_error("Input tensor mismatches with indices at dim-0."); + throw std::runtime_error("Input tensor mismatches with indices at dim-0."); } - int64_t split_index = 0; - if (dev_table.has_value()) { - split_index = dev_table.value().size(0); + int64_t stride = table.size(1); + + if (dim != stride) { + throw std::runtime_error( + "Input tensor's dim1 should equal the table's."); + } + + if (num_total == 0) { + return; } auto val_type = get_data_type(input); auto index_type = get_data_type(indices); - int64_t num_total = indices.size(0); - constexpr int kWarpSize = 32; constexpr int MULTIPLIER = 4; constexpr int BLOCK_SIZE_VEC = 64; @@ -549,25 +499,22 @@ void store_to_combined_table(std::optional dev_table, DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, ValueType, [&] { DISPATCH_OFFSET_INT_TYPE(index_type, IndexType, [&] { - auto dev_ptr = get_pointer(dev_table); - auto uvm_ptr = get_pointer(uvm_table); + auto table_ptr = get_pointer(table); auto input_ptr = get_pointer(input); auto index_ptr = get_pointer(indices); if (dim % 4 == 0) { - store_to_combined_table_kernel_vec4 + store_to_table_kernel_vec4 <<>>( - num_total, stride, split_index, dev_ptr, uvm_ptr, input_ptr, - index_ptr); + num_total, stride, table_ptr, input_ptr, index_ptr); } else { int block_size = dim < device_prop.max_thread_per_block ? dim : device_prop.max_thread_per_block; int grid_size = num_total; - store_to_combined_table_kernel - <<>>( - num_total, stride, split_index, dev_ptr, uvm_ptr, input_ptr, - index_ptr); + store_to_table_kernel + <<>>(num_total, stride, table_ptr, + input_ptr, index_ptr); } }); }); @@ -728,20 +675,18 @@ void bind_dyn_emb_op(py::module &m) { m.def("gather_embedding", &gather_embedding, "Gather embedding based on index.", py::arg("input"), py::arg("output"), py::arg("index")); - + m.def("gather_embedding_pooled", &gather_embedding_pooled, - "Gather embedding with pooling (SUM/MEAN) based on index and offsets.", - py::arg("input"), py::arg("output"), py::arg("index"), - py::arg("offsets"), py::arg("combiner"), py::arg("total_D"), - py::arg("batch_size"), py::arg("D_offsets") = py::none(), - py::arg("max_D") = 0); - - m.def("load_from_combined_table", &load_from_combined_table, - "load_from_combined_table", py::arg("dev_table"), py::arg("uvm_table"), - py::arg("indices"), py::arg("output")); - - m.def("store_to_combined_table", &store_to_combined_table, - "store_to_combined_table", py::arg("dev_table"), py::arg("uvm_table"), + "Gather embedding with pooling (SUM/MEAN) based on index and offsets.", + py::arg("input"), py::arg("output"), py::arg("index"), + py::arg("offsets"), py::arg("combiner"), py::arg("total_D"), + py::arg("batch_size"), py::arg("D_offsets") = py::none(), + py::arg("max_D") = 0); + + m.def("load_from_table", &load_from_table, "load_from_table", + py::arg("table"), py::arg("indices"), py::arg("output")); + + m.def("store_to_table", &store_to_table, "store_to_table", py::arg("table"), py::arg("indices"), py::arg("input")); m.def("select_insert_failed_values", &select_insert_failed_values, diff --git a/corelib/dynamicemb/src/optimizer.cu b/corelib/dynamicemb/src/optimizer.cu index cdbe70c34..49ed076ed 100644 --- a/corelib/dynamicemb/src/optimizer.cu +++ b/corelib/dynamicemb/src/optimizer.cu @@ -31,12 +31,10 @@
constexpr int OPTIMIZER_BLOCKSIZE = 1024; template -void launch_update_kernel_for_combined_table( - GradType *grads, WeightType *dev_table, WeightType *uvm_table, - IndexType *indices, OptimizerType opt, int64_t const ev_nums, - uint32_t const dim, int64_t const stride, int64_t const split_index, - int device_id, - std::function smem_size_f = [](int block_size) { return 0; }) { +void launch_update_kernel_for_table(GradType *grads, WeightType *table, + IndexType *indices, OptimizerType opt, + int64_t const ev_nums, uint32_t const dim, + int64_t const stride, int device_id) { auto stream = at::cuda::getCurrentCUDAStream().stream(); auto &device_prop = DeviceProp::getDeviceProp(device_id); if (dim % 4 == 0) { @@ -57,25 +55,21 @@ void launch_update_kernel_for_combined_table( auto kernel = update4_with_index_kernel; kernel<<>>( - ev_nums, dim, stride, split_index, grads, dev_table, uvm_table, indices, - nullptr, opt); + ev_nums, dim, stride, grads, table, indices, nullptr, opt); } else { int block_size = dim > OPTIMIZER_BLOCKSIZE ? OPTIMIZER_BLOCKSIZE : dim; int grid_size = ev_nums; auto kernel = update_with_index_kernel; - kernel<<>>( - ev_nums, dim, stride, split_index, grads, dev_table, uvm_table, indices, - nullptr, opt); + kernel<<>>(ev_nums, dim, stride, grads, + table, indices, nullptr, opt); } DEMB_CUDA_KERNEL_LAUNCH_CHECK(); } -void sgd_update_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - float const lr) { +void sgd_update_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, float const lr) { int64_t ev_nums = grads.size(0); int64_t dim = grads.size(1); if (ev_nums == 0) @@ -83,56 +77,37 @@ void sgd_update_for_combined_table(at::Tensor grads, at::Tensor indices, TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); - DataType val_type; - int64_t stride = -1; - if ((not dev_table.has_value()) and (not uvm_table.has_value())) { - throw std::runtime_error("Two tables cannot both be None."); - } else { - if (dev_table.has_value()) { - val_type = get_data_type(dev_table.value()); - stride = dev_table.value().size(1); - } else { - val_type = get_data_type(uvm_table.value()); - stride = uvm_table.value().size(1); - } - } + DataType val_type = get_data_type(table); + int64_t stride = table.size(1); auto grad_type = get_data_type(grads); auto index_type = get_data_type(indices); - int64_t split_index = 0; - if (dev_table.has_value()) { - split_index = dev_table.value().size(0); - } - int device_id = grads.device().index(); DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { DISPATCH_OFFSET_INT_TYPE(index_type, i_t, [&] { auto grad_ptr = get_pointer(grads); - auto dev_ptr = get_pointer(dev_table); - auto uvm_ptr = get_pointer(uvm_table); + auto table_ptr = get_pointer(table); auto index_ptr = get_pointer(indices); SgdVecOptimizer opt{lr}; - launch_update_kernel_for_combined_table( - grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride, - split_index, device_id); + launch_update_kernel_for_table( + grad_ptr, table_ptr, index_ptr, opt, ev_nums, dim, stride, + device_id); }); }); }); } -void adam_update_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - int64_t state_dim, const float lr, - const float beta1, const float beta2, - const float eps, const float weight_decay, - const uint32_t iter_num) { +void 
adam_update_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, int64_t state_dim, const float lr, + const float beta1, const float beta2, + const float eps, const float weight_decay, + const uint32_t iter_num) { int64_t ev_nums = grads.size(0); int64_t dim = grads.size(1); if (ev_nums == 0) @@ -140,54 +115,36 @@ void adam_update_for_combined_table(at::Tensor grads, at::Tensor indices, TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); - DataType val_type; - int64_t stride = -1; - if ((not dev_table.has_value()) and (not uvm_table.has_value())) { - throw std::runtime_error("Two tables cannot both be None."); - } else { - if (dev_table.has_value()) { - val_type = get_data_type(dev_table.value()); - stride = dev_table.value().size(1); - } else { - val_type = get_data_type(uvm_table.value()); - stride = uvm_table.value().size(1); - } - } + DataType val_type = get_data_type(table); + int64_t stride = table.size(1); auto grad_type = get_data_type(grads); auto index_type = get_data_type(indices); - int64_t split_index = 0; - if (dev_table.has_value()) { - split_index = dev_table.value().size(0); - } int device_id = grads.device().index(); DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { DISPATCH_OFFSET_INT_TYPE(index_type, i_t, [&] { auto grad_ptr = get_pointer(grads); - auto dev_ptr = get_pointer(dev_table); - auto uvm_ptr = get_pointer(uvm_table); + auto table_ptr = get_pointer(table); auto index_ptr = get_pointer(indices); AdamVecOptimizer opt{lr, beta1, beta2, eps, weight_decay, iter_num}; - launch_update_kernel_for_combined_table( - grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride, - split_index, device_id); + launch_update_kernel_for_table( + grad_ptr, table_ptr, index_ptr, opt, ev_nums, dim, stride, + device_id); }); }); }); } -void adagrad_update_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - int64_t state_dim, const float lr, - const float eps) { +void adagrad_update_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, int64_t state_dim, + const float lr, const float eps) { int64_t ev_nums = grads.size(0); int64_t dim = grads.size(1); @@ -196,53 +153,34 @@ void adagrad_update_for_combined_table(at::Tensor grads, at::Tensor indices, TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); - DataType val_type; - int64_t stride = -1; - if ((not dev_table.has_value()) and (not uvm_table.has_value())) { - throw std::runtime_error("Two tables cannot both be None."); - } else { - if (dev_table.has_value()) { - val_type = get_data_type(dev_table.value()); - stride = dev_table.value().size(1); - } else { - val_type = get_data_type(uvm_table.value()); - stride = uvm_table.value().size(1); - } - } + DataType val_type = get_data_type(table); + int64_t stride = table.size(1); auto grad_type = get_data_type(grads); auto index_type = get_data_type(indices); - int64_t split_index = 0; - if (dev_table.has_value()) { - split_index = dev_table.value().size(0); - } - int device_id = grads.device().index(); DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { DISPATCH_OFFSET_INT_TYPE(index_type, i_t, [&] { auto grad_ptr = get_pointer(grads); - auto dev_ptr = get_pointer(dev_table); - auto uvm_ptr = get_pointer(uvm_table); + auto table_ptr = 
get_pointer(table); auto index_ptr = get_pointer(indices); AdaGradVecOptimizer opt{lr, eps}; - launch_update_kernel_for_combined_table( - grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride, - split_index, device_id); + launch_update_kernel_for_table( + grad_ptr, table_ptr, index_ptr, opt, ev_nums, dim, stride, + device_id); }); }); }); } -void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - int64_t state_dim, const float lr, - const float eps) { +void rowwise_adagrad_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, int64_t state_dim, + const float lr, const float eps) { int64_t ev_nums = grads.size(0); int64_t dim = grads.size(1); @@ -251,45 +189,27 @@ void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices, TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); - DataType val_type; - int64_t stride = -1; - if ((not dev_table.has_value()) and (not uvm_table.has_value())) { - throw std::runtime_error("Two tables cannot both be None."); - } else { - if (dev_table.has_value()) { - val_type = get_data_type(dev_table.value()); - stride = dev_table.value().size(1); - } else { - val_type = get_data_type(uvm_table.value()); - stride = uvm_table.value().size(1); - } - } + DataType val_type = get_data_type(table); + int64_t stride = table.size(1); auto grad_type = get_data_type(grads); auto index_type = get_data_type(indices); - int64_t split_index = 0; - if (dev_table.has_value()) { - split_index = dev_table.value().size(0); - } - int device_id = grads.device().index(); DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { DISPATCH_OFFSET_INT_TYPE(index_type, i_t, [&] { auto grad_ptr = get_pointer(grads); - auto dev_ptr = get_pointer(dev_table); - auto uvm_ptr = get_pointer(uvm_table); + auto table_ptr = get_pointer(table); auto index_ptr = get_pointer(indices); RowWiseAdaGradVecOptimizer opt{lr, eps}; - launch_update_kernel_for_combined_table( - grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride, - split_index, device_id, - [](int block_size) { return block_size * sizeof(float); }); + launch_update_kernel_for_table( + grad_ptr, table_ptr, index_ptr, opt, ev_nums, dim, stride, + device_id); }); }); }); @@ -299,27 +219,23 @@ void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices, // PYTHON WRAP void bind_optimizer_kernel_op(py::module &m) { - m.def("sgd_update_for_combined_table", - &dyn_emb::sgd_update_for_combined_table, + m.def("sgd_update_for_table", &dyn_emb::sgd_update_for_table, "SGD optimizer for Dynamic Emb", py::arg("grads"), py::arg("indices"), - py::arg("dev_table"), py::arg("uvm_table"), py::arg("lr")); + py::arg("table"), py::arg("lr")); - m.def("adam_update_for_combined_table", - &dyn_emb::adam_update_for_combined_table, + m.def("adam_update_for_table", &dyn_emb::adam_update_for_table, "Adam optimizer for Dynamic Emb", py::arg("grads"), py::arg("indices"), - py::arg("dev_table"), py::arg("uvm_table"), py::arg("state_dim"), - py::arg("lr"), py::arg("beta1"), py::arg("beta2"), py::arg("eps"), - py::arg("weight_decay"), py::arg("iter_num")); + py::arg("table"), py::arg("state_dim"), py::arg("lr"), py::arg("beta1"), + py::arg("beta2"), py::arg("eps"), py::arg("weight_decay"), + py::arg("iter_num")); - m.def("adagrad_update_for_combined_table", - &dyn_emb::adagrad_update_for_combined_table, + 
m.def("adagrad_update_for_table", &dyn_emb::adagrad_update_for_table, "Adagrad optimizer for Dynamic Emb", py::arg("grads"), - py::arg("indices"), py::arg("dev_table"), py::arg("uvm_table"), - py::arg("state_dim"), py::arg("lr"), py::arg("eps")); + py::arg("indices"), py::arg("table"), py::arg("state_dim"), + py::arg("lr"), py::arg("eps")); - m.def("rowwise_adagrad_for_combined_table", - &dyn_emb::rowwise_adagrad_for_combined_table, + m.def("rowwise_adagrad_for_table", &dyn_emb::rowwise_adagrad_for_table, "Row Wise Adagrad optimizer for Dynamic Emb", py::arg("grads"), - py::arg("indices"), py::arg("dev_table"), py::arg("uvm_table"), - py::arg("state_dim"), py::arg("lr"), py::arg("eps")); + py::arg("indices"), py::arg("table"), py::arg("state_dim"), + py::arg("lr"), py::arg("eps")); } diff --git a/corelib/dynamicemb/src/optimizer.h b/corelib/dynamicemb/src/optimizer.h index 390582236..e05fa66dc 100644 --- a/corelib/dynamicemb/src/optimizer.h +++ b/corelib/dynamicemb/src/optimizer.h @@ -35,31 +35,22 @@ All rights reserved. # SPDX-License-Identifier: Apache-2.0 namespace dyn_emb { -void sgd_update_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - float const lr); - -void adam_update_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - - int64_t state_dim, const float lr, - const float beta1, const float beta2, - const float eps, const float weight_decay, - const uint32_t iter_num); - -void adagrad_update_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - int64_t state_dim, const float lr, - const float eps); - -void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices, - std::optional dev_table, - std::optional uvm_table, - int64_t state_dim, const float lr, - const float eps); +void sgd_update_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, float const lr); + +void adam_update_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, int64_t state_dim, const float lr, + const float beta1, const float beta2, + const float eps, const float weight_decay, + const uint32_t iter_num); + +void adagrad_update_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, int64_t state_dim, + const float lr, const float eps); + +void rowwise_adagrad_for_table(at::Tensor grads, at::Tensor indices, + at::Tensor table, int64_t state_dim, + const float lr, const float eps); } // namespace dyn_emb #endif // OPTIMIZER_H diff --git a/corelib/dynamicemb/src/optimizer_kernel.cuh b/corelib/dynamicemb/src/optimizer_kernel.cuh index a37c211c7..8a30feac2 100644 --- a/corelib/dynamicemb/src/optimizer_kernel.cuh +++ b/corelib/dynamicemb/src/optimizer_kernel.cuh @@ -484,10 +484,9 @@ template __global__ void update4_with_index_kernel(const uint32_t num_keys, const uint32_t dim, - const uint32_t stride, const uint32_t split_index, - const wgrad_t *grad_evs, weight_t *dev_table, - weight_t *uvm_table, index_t *indices, - const bool *masks, OptimizerFunc optimizer) { + const uint32_t stride, const wgrad_t *grad_evs, + weight_t *table, index_t *indices, const bool *masks, + OptimizerFunc optimizer) { constexpr int kWarpSize = 32; const int warp_num_per_block = blockDim.x / kWarpSize; const int warp_id_in_block = threadIdx.x / kWarpSize; @@ -500,12 +499,7 @@ update4_with_index_kernel(const uint32_t num_keys, const uint32_t dim, continue; } - weight_t *weight_ptr = nullptr; - if (index < 
split_index) { - weight_ptr = dev_table + index * stride; - } else { - weight_ptr = uvm_table + (index - split_index) * stride; - } + weight_t *weight_ptr = table + index * stride; const wgrad_t *grad_ptr = grad_evs + ev_id * dim; OptimizierInput input{grad_ptr, weight_ptr, dim}; @@ -517,10 +511,9 @@ template __global__ void update_with_index_kernel(const uint32_t num_keys, const uint32_t dim, - const uint32_t stride, const uint32_t split_index, - const wgrad_t *grad_evs, weight_t *dev_table, - weight_t *uvm_table, index_t *indices, - const bool *masks, OptimizerFunc optimizer) { + const uint32_t stride, const wgrad_t *grad_evs, + weight_t *table, index_t *indices, const bool *masks, + OptimizerFunc optimizer) { constexpr int kWarpSize = 32; for (uint32_t ev_id = blockIdx.x; ev_id < num_keys; ev_id += gridDim.x) { @@ -530,13 +523,7 @@ update_with_index_kernel(const uint32_t num_keys, const uint32_t dim, if ((!mask) or (index == -1)) { continue; } - weight_t *weight_ptr = nullptr; - if (index < split_index) { - weight_ptr = dev_table + index * stride; - } else { - weight_ptr = uvm_table + (index - split_index) * stride; - } - + weight_t *weight_ptr = table + index * stride; const wgrad_t *grad_ptr = grad_evs + ev_id * dim; OptimizierInput input{grad_ptr, weight_ptr, dim}; diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py index 22eae8737..4a592f948 100644 --- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py +++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py @@ -482,7 +482,6 @@ def test_forward_train_eval( embedding_dtype=value_type, device_id=device_id, score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, - caching=caching, local_hbm_for_values=1024**3, external_storage=PS, ) @@ -648,7 +647,6 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist embedding_dtype=value_type, device_id=device_id, score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, - caching=caching, local_hbm_for_values=1024**3, external_storage=PS, ) @@ -791,7 +789,6 @@ def test_prefetch_flush_in_cache(opt_type, opt_params, deterministic, PS): embedding_dtype=value_type, device_id=device_id, score_strategy=DynamicEmbScoreStrategy.STEP, - caching=True, local_hbm_for_values=1024**3, external_storage=PS, ) @@ -976,7 +973,6 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc embedding_dtype=value_type, device_id=device_id, score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, - caching=caching, local_hbm_for_values=init_capacity * dim * 4, external_storage=PS, ) @@ -1093,7 +1089,6 @@ def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS): embedding_dtype=value_type, device_id=device_id, score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, - caching=caching, local_hbm_for_values=1024**3, external_storage=PS, ) diff --git a/corelib/dynamicemb/test/unit_tests/incremental_dump/test_batched_dynamicemb_tables.py b/corelib/dynamicemb/test/unit_tests/incremental_dump/test_batched_dynamicemb_tables.py index 33abe9e49..9dfab7c27 100644 --- a/corelib/dynamicemb/test/unit_tests/incremental_dump/test_batched_dynamicemb_tables.py +++ b/corelib/dynamicemb/test/unit_tests/incremental_dump/test_batched_dynamicemb_tables.py @@ -133,7 +133,6 @@ def test_without_eviction( local_hbm_for_values=1024**3, score_strategy=score_strategy, num_of_buckets_per_alloc=num_embeddings[i] // bucket_capacity, - caching=caching, ) for i in 
range(table_num) ] diff --git a/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py b/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py index bf6bb4280..11a874054 100644 --- a/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py +++ b/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py @@ -311,8 +311,7 @@ def test_admission_strategy_validation( score_strategy ), # Use timestamp for admission use_index_dedup=use_index_dedup, - caching=caching, - cache_capacity_ratio=cache_capacity_ratio if caching else 0.1, + cache_capacity_ratio=cache_capacity_ratio if caching else 1.0, admit_strategy=admission_strategy, # Pass admission strategy ) diff --git a/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py b/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py index 10fd07a31..c24e7335c 100644 --- a/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py +++ b/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py @@ -199,7 +199,6 @@ def apply_dmp( device: torch.device, score_strategy: DynamicEmbScoreStrategy = DynamicEmbScoreStrategy.LFU, use_index_dedup: bool = False, - caching: bool = False, cache_capacity_ratio: float = 0.5, admit_strategy: AdmissionStrategy = None, ): @@ -213,11 +212,7 @@ def apply_dmp( tmp_type = eb_config.data_type embedding_type_bytes = DATA_TYPE_NUM_BITS[tmp_type] / 8 - emb_num_embeddings = ( - eb_config.num_embeddings * cache_capacity_ratio - if caching - else eb_config.num_embeddings - ) + emb_num_embeddings = eb_config.num_embeddings emb_num_embeddings_next_power_of_2 = 2 ** math.ceil( math.log2(emb_num_embeddings) ) # HKV need embedding vector num is power of 2 @@ -249,12 +244,20 @@ def apply_dmp( else 0 ) - # Include optimizer state in HBM calculation - total_hbm_need = ( + # Include optimizer state in HBM calculation. + # When cache_capacity_ratio < 1, scale down so that only a + # fraction of the table fits in HBM (triggers cache+storage). + # When cache_capacity_ratio >= 1, use full size (all-HBM mode). 
+ full_table_hbm = ( embedding_type_bytes * (dim + optimizer_state_dim) * emb_num_embeddings_next_power_of_2 ) + total_hbm_need = int( + full_table_hbm * cache_capacity_ratio + if cache_capacity_ratio < 1.0 + else full_table_hbm + ) admission_counter = KVCounter( max(1024 * 1024, emb_num_embeddings_next_power_of_2 // 4) @@ -268,7 +271,6 @@ def apply_dmp( ), bucket_capacity=emb_num_embeddings_next_power_of_2, max_capacity=emb_num_embeddings_next_power_of_2, - caching=caching, local_hbm_for_values=1024**3, admit_strategy=admit_strategy, admission_counter=admission_counter, @@ -308,7 +310,6 @@ def create_model( optimizer_kwargs: Dict[str, Any], score_strategy: DynamicEmbScoreStrategy = DynamicEmbScoreStrategy.LFU, use_index_dedup: bool = False, - caching: bool = False, cache_capacity_ratio: float = 0.5, admit_strategy: AdmissionStrategy = None, ): @@ -344,7 +345,6 @@ def create_model( torch.device(f"cuda:{torch.cuda.current_device()}"), score_strategy=score_strategy, use_index_dedup=use_index_dedup, - caching=caching, cache_capacity_ratio=cache_capacity_ratio, admit_strategy=admit_strategy, ) diff --git a/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py b/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py index 42f0e1402..af39a3197 100644 --- a/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py +++ b/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py @@ -270,8 +270,7 @@ def test_lfu_score_validation( optimizer_kwargs=optimizer_kwargs, score_strategy=DynamicEmbScoreStrategy.LFU, use_index_dedup=use_index_dedup, - caching=caching, - cache_capacity_ratio=cache_capacity_ratio if caching else 0.1, + cache_capacity_ratio=cache_capacity_ratio if caching else 1.0, ) # Generate features with frequency tracking diff --git a/examples/hstu/test_utils.py b/examples/hstu/test_utils.py index 55942b68c..c0b3ec012 100755 --- a/examples/hstu/test_utils.py +++ b/examples/hstu/test_utils.py @@ -577,8 +577,6 @@ def create_model( "item": DynamicEmbTableOptions( global_hbm_for_values=1024 * 1024, # 1M HBM (maybe cached) score_strategy=DynamicEmbScoreStrategy.STEP, - caching=pipeline_type - == "prefetch", # when prefetch is enabled, we must enable caching ), } if use_dynamic_emb diff --git a/examples/hstu/training/pretrain_gr_ranking.py b/examples/hstu/training/pretrain_gr_ranking.py index 1b67a3a84..2d1116d14 100644 --- a/examples/hstu/training/pretrain_gr_ranking.py +++ b/examples/hstu/training/pretrain_gr_ranking.py @@ -80,9 +80,7 @@ def main(): args = parser.parse_args() gin.parse_config_file(args.gin_config_file) trainer_args = TrainerArgs() - dataset_args, embedding_args = get_dataset_and_embedding_args( - trainer_args.pipeline_type == "prefetch" - ) + dataset_args, embedding_args = get_dataset_and_embedding_args() network_args = NetworkArgs() optimizer_args = OptimizerArgs() tp_args = TensorModelParallelArgs() diff --git a/examples/hstu/training/pretrain_gr_retrieval.py b/examples/hstu/training/pretrain_gr_retrieval.py index 06266cf10..3323058b3 100644 --- a/examples/hstu/training/pretrain_gr_retrieval.py +++ b/examples/hstu/training/pretrain_gr_retrieval.py @@ -77,9 +77,7 @@ def main(): args = parser.parse_args() gin.parse_config_file(args.gin_config_file) trainer_args = TrainerArgs() - dataset_args, embedding_args = get_dataset_and_embedding_args( - caching=trainer_args.pipeline_type == "prefetch" - ) + dataset_args, embedding_args = get_dataset_and_embedding_args() network_args = NetworkArgs() optimizer_args = OptimizerArgs() tp_args = TensorModelParallelArgs() diff --git 
a/examples/hstu/training/trainer/utils.py b/examples/hstu/training/trainer/utils.py index 766ea51db..9df4d6f42 100644 --- a/examples/hstu/training/trainer/utils.py +++ b/examples/hstu/training/trainer/utils.py @@ -388,14 +388,11 @@ def create_dynamic_optitons_dict( safe_check_mode=DynamicEmbCheckMode.IGNORE, bucket_capacity=128, training=training, - caching=embedding_args.caching, ) return dynamic_options_dict -def get_dataset_and_embedding_args( - caching: bool = False, -) -> Tuple[ +def get_dataset_and_embedding_args() -> Tuple[ Union[DatasetArgs, BenchmarkDatasetArgs], List[Union[DynamicEmbeddingArgs, EmbeddingArgs]], ]: @@ -449,14 +446,12 @@ def get_dataset_and_embedding_args( table_name="video_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), DynamicEmbeddingArgs( feature_names=["user_id"], table_name="user_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), ] elif dataset_args.dataset_name == "kuairand-1k": @@ -502,14 +497,12 @@ def get_dataset_and_embedding_args( table_name="video_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), DynamicEmbeddingArgs( feature_names=["user_id"], table_name="user_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), ] elif dataset_args.dataset_name == "kuairand-27k": @@ -555,14 +548,12 @@ def get_dataset_and_embedding_args( table_name="video_id", item_vocab_size_or_capacity=32038725, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), DynamicEmbeddingArgs( feature_names=["user_id"], table_name="user_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), ] elif dataset_args.dataset_name == "ml-1m": @@ -602,14 +593,12 @@ def get_dataset_and_embedding_args( table_name="movie_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), DynamicEmbeddingArgs( feature_names=["user_id"], table_name="user_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), ] elif dataset_args.dataset_name == "ml-20m": @@ -625,14 +614,12 @@ def get_dataset_and_embedding_args( table_name="movie_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=caching, ), DynamicEmbeddingArgs( feature_names=["user_id"], table_name="user_id", item_vocab_size_or_capacity=HASH_SIZE, item_vocab_gpu_capacity_ratio=0.5, - caching=True, ), ] else: diff --git a/examples/hstu/utils/gin_config_args.py b/examples/hstu/utils/gin_config_args.py index 1c611deb2..c6d5f9e31 100644 --- a/examples/hstu/utils/gin_config_args.py +++ b/examples/hstu/utils/gin_config_args.py @@ -147,8 +147,6 @@ class DynamicEmbeddingArgs(EmbeddingArgs): item_vocab_gpu_capacity_ratio (Optional[float]): Item vocabulary GPU capacity ratio (lowest priority). Default: None. evict_strategy (str): Eviction strategy: "lru" or "lfu". Default: "lru". - caching (bool): Enable caching on HBM. When caching is enabled, the - global_hbm_for_values indicates the cache size. Default: False. Note: - sharding_type is automatically set to "model_parallel". @@ -166,7 +164,6 @@ class DynamicEmbeddingArgs(EmbeddingArgs): item_vocab_gpu_capacity_ratio: Optional[float] = None evict_strategy: str = "lru" - caching: bool = False def __post_init__(self): self.sharding_type = "model_parallel"
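Reviewer note (illustrative, not part of the patch): the minimal Python sketch below restates the placement rule that the reworked _create_cache_storage applies per table, based on local_hbm_for_values and the per-row value size. The helper value_size_bytes and the string stand-ins for tables are assumptions for illustration only; the real code builds DynamicEmbeddingTable instances and also honors bucket_capacity.

    # Sketch of the three-way placement decision: host-only storage,
    # whole table in HBM, or GPU cache + host/external storage.
    from typing import Optional, Tuple


    def value_size_bytes(dim: int, optim_state_dim: int, dtype_bytes: int) -> int:
        # Bytes per row: embedding vector plus optimizer state, as in get_value_size.
        return (dim + optim_state_dim) * dtype_bytes


    def place_table(
        max_capacity: int,
        dim: int,
        optim_state_dim: int,
        dtype_bytes: int,
        hbm_budget: int,
    ) -> Tuple[Optional[str], str, bool]:
        """Return (cache, storage, caching_enabled) for one table."""
        row_bytes = value_size_bytes(dim, optim_state_dim, dtype_bytes)
        total_table_bytes = row_bytes * max_capacity
        if hbm_budget == 0:
            # No HBM budget -> storage only, values live in host (UVM) memory.
            return None, "host_table", False
        if total_table_bytes <= hbm_budget:
            # Entire table fits in HBM -> one GPU table acts as cache and storage.
            return "gpu_table", "gpu_table", True
        # Partial budget -> GPU cache sized by the budget, full table elsewhere.
        cache_capacity = hbm_budget // row_bytes
        return f"gpu_cache[{cache_capacity} rows]", "host_or_external_storage", True


    if __name__ == "__main__":
        # 1M rows of dim-128 fp32 embeddings with one extra optimizer slot per row.
        print(place_table(1_000_000, 128, 1, 4, 0))              # host only
        print(place_table(1_000_000, 128, 1, 4, 2 * 1024**3))    # fits entirely in HBM
        print(place_table(1_000_000, 128, 1, 4, 128 * 1024**2))  # GPU cache + storage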
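Reviewer note (illustrative, not part of the patch): a small NumPy sketch of the addressing change behind replacing load_from_combined_table/store_to_combined_table with load_from_table/store_to_table. The dev/uvm split at split_index disappears; a single table tensor (whether its memory is HBM or UVM) is indexed directly. Array and function names here are hypothetical.

    import numpy as np


    def load_row_combined(dev_table, uvm_table, index):
        # Old layout: rows [0, split_index) live in dev_table, the rest in uvm_table.
        split_index = dev_table.shape[0]
        if index < split_index:
            return dev_table[index]
        return uvm_table[index - split_index]


    def load_row_single(table, index):
        # New layout: one table holds every row, wherever its memory is allocated.
        return table[index]


    if __name__ == "__main__":
        dev = np.arange(8, dtype=np.float32).reshape(4, 2)       # 4 rows on "device"
        uvm = np.arange(8, 16, dtype=np.float32).reshape(4, 2)   # 4 rows on "host"
        combined = np.vstack([dev, uvm])                         # new single table
        for i in (1, 6):
            assert np.array_equal(
                load_row_combined(dev, uvm, i), load_row_single(combined, i)
            )
        print("same rows resolved either way")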