@@ -404,7 +404,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
404404
405405"""
406406For torchrec's adam optimizer, it will increment the optimizer_step in every forward,
407- which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
407+ which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
408408 to control(not verified) it.
409409"""
410410
@@ -444,6 +444,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
444444 [
445445 (True , DynamicEmbPoolingMode .NONE , [8 , 8 , 8 ]),
446446 (False , DynamicEmbPoolingMode .NONE , [16 , 16 , 16 ]),
447+ (False , DynamicEmbPoolingMode .NONE , [17 , 17 , 17 ]),
447448 (False , DynamicEmbPoolingMode .SUM , [128 , 32 , 16 ]),
448449 (False , DynamicEmbPoolingMode .MEAN , [4 , 8 , 16 ]),
449450 ],
@@ -467,7 +468,10 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
467468 max_capacity = 2048
468469
469470 dyn_emb_table_options_list = []
471+ cmp_with_torchrec = True
470472 for dim in dims :
473+ if dim % 4 != 0 :
474+ cmp_with_torchrec = False
471475 dyn_emb_table_options = DynamicEmbTableOptions (
472476 dim = dim ,
473477 init_capacity = max_capacity ,
@@ -492,49 +496,68 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
492496 ** opt_params ,
493497 )
494498 num_embs = [max_capacity // 2 for d in dims ]
495- stbe = create_split_table_batched_embedding (
496- table_names ,
497- feature_table_map ,
498- OPTIM_TYPE [opt_type ],
499- opt_params ,
500- dims ,
501- num_embs ,
502- POOLING_MODE [pooling_mode ],
503- device ,
504- )
505- init_embedding_tables (stbe , bdeb )
506- """
507- feature number = 4, batch size = 2
508499
509- f0 [0,1], [12],
510- f1 [64,8], [12],
511- f2 [15, 2, 7], [105],
512- f3 [], [0]
513- """
514- for i in range (10 ):
515- indices = torch .tensor (
516- [0 , 1 , 12 , 64 , 8 , 12 , 15 , 2 , 7 , 105 , 0 ], device = device
517- ).to (key_type )
518- offsets = torch .tensor ([0 , 2 , 3 , 5 , 6 , 9 , 10 , 10 , 11 ], device = device ).to (
519- key_type
500+ if cmp_with_torchrec :
501+ stbe = create_split_table_batched_embedding (
502+ table_names ,
503+ feature_table_map ,
504+ OPTIM_TYPE [opt_type ],
505+ opt_params ,
506+ dims ,
507+ num_embs ,
508+ POOLING_MODE [pooling_mode ],
509+ device ,
520510 )
511+ init_embedding_tables (stbe , bdeb )
512+ """
513+ feature number = 4, batch size = 2
514+
515+ f0 [0,1], [12],
516+ f1 [64,8], [12],
517+ f2 [15, 2, 7], [105],
518+ f3 [], [0]
519+ """
520+ for i in range (10 ):
521+ indices = torch .tensor (
522+ [0 , 1 , 12 , 64 , 8 , 12 , 15 , 2 , 7 , 105 , 0 ], device = device
523+ ).to (key_type )
524+ offsets = torch .tensor ([0 , 2 , 3 , 5 , 6 , 9 , 10 , 10 , 11 ], device = device ).to (
525+ key_type
526+ )
521527
522- embs_bdeb = bdeb (indices , offsets )
523- embs_stbe = stbe (indices , offsets )
524-
525- torch .cuda .synchronize ()
526- with torch .no_grad ():
527- torch .testing .assert_close (embs_bdeb , embs_stbe , rtol = 1e-06 , atol = 1e-06 )
528+ embs_bdeb = bdeb (indices , offsets )
529+ embs_stbe = stbe (indices , offsets )
530+
531+ torch .cuda .synchronize ()
532+ with torch .no_grad ():
533+ torch .testing .assert_close (embs_bdeb , embs_stbe , rtol = 1e-06 , atol = 1e-06 )
534+
535+ loss = embs_bdeb .mean ()
536+ loss .backward ()
537+ loss_stbe = embs_stbe .mean ()
538+ loss_stbe .backward ()
539+
540+ torch .cuda .synchronize ()
541+ torch .testing .assert_close (loss , loss_stbe )
542+
543+ print (f"Passed iteration { i } " )
544+ else :
545+        # This scenario does not verify numerical correctness against torchrec;
545+        # it only checks that the forward/backward pass runs without error.
546+ for i in range (10 ):
547+ indices = torch .tensor (
548+ [0 , 1 , 12 , 64 , 8 , 12 , 15 , 2 , 7 , 105 , 0 ], device = device
549+ ).to (key_type )
550+ offsets = torch .tensor ([0 , 2 , 3 , 5 , 6 , 9 , 10 , 10 , 11 ], device = device ).to (
551+ key_type
552+ )
528553
529- loss = embs_bdeb .mean ()
530- loss .backward ()
531- loss_stbe = embs_stbe .mean ()
532- loss_stbe .backward ()
554+ embs_bdeb = bdeb (indices , offsets )
555+ loss = embs_bdeb .mean ()
556+ loss .backward ()
533557
534- torch .cuda .synchronize ()
535- torch .testing .assert_close (loss , loss_stbe )
558+ torch .cuda .synchronize ()
536559
537- print (f"Passed iteration { i } " )
560+ print (f"Passed iteration { i } " )
538561
539562 if deterministic :
540563 del os .environ ["DEMB_DETERMINISM_MODE" ]
@@ -853,3 +876,102 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc
853876
854877 del os .environ ["DEMB_DETERMINISM_MODE" ]
855878 print ("all check passed" )
879+
880+
@pytest.mark.parametrize(
    "opt_type,opt_params",
    [
        (EmbOptimType.SGD, {"learning_rate": 0.3}),
        (
            EmbOptimType.EXACT_ROWWISE_ADAGRAD,
            {
                "learning_rate": 0.3,
                "eps": 3e-5,
            },
        ),
    ],
)
@pytest.mark.parametrize("dim", [7, 8])
@pytest.mark.parametrize("caching", [True, False])
@pytest.mark.parametrize("deterministic", [True, False])
@pytest.mark.parametrize("PS", [None])
def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS):
    """Smoke-test forward/backward/eval of BatchedDynamicEmbeddingTablesV2 on an
    all-empty batch (every feature has zero indices).

    Covers both an alignment-friendly dim (8) and an odd dim (7), with and
    without caching/prefetch, and with and without determinism mode. No
    numerical comparison is made; the test only checks that an empty batch
    does not crash in train, backward, or eval paths.

    Args:
        opt_type: optimizer enum passed to the embedding table.
        opt_params: keyword args forwarded to the optimizer (e.g. learning_rate).
        dim: embedding dimension used for all three tables.
        caching: whether the table uses the caching/prefetch path.
        deterministic: whether DEMB_DETERMINISM_MODE is enabled for this run.
        PS: external storage backend (parametrized as None here).
    """
    print(
        f"step in test_empty_batch, opt_type = {opt_type} opt_params = {opt_params}"
    )

    if deterministic:
        os.environ["DEMB_DETERMINISM_MODE"] = "ON"

    try:
        assert torch.cuda.is_available()
        device_id = 0
        device = torch.device(f"cuda:{device_id}")

        dims = [dim, dim, dim]
        table_names = ["table0", "table1", "table2"]
        key_type = torch.int64
        value_type = torch.float32

        init_capacity = 1024
        max_capacity = 2048

        dyn_emb_table_options_list = []
        # NOTE: use a distinct loop variable so the `dim` parameter is not shadowed.
        for table_dim in dims:
            dyn_emb_table_options = DynamicEmbTableOptions(
                dim=table_dim,
                init_capacity=init_capacity,
                max_capacity=max_capacity,
                index_type=key_type,
                embedding_dtype=value_type,
                device_id=device_id,
                score_strategy=DynamicEmbScoreStrategy.TIMESTAMP,
                caching=caching,
                local_hbm_for_values=1024**3,
                external_storage=PS,
            )
            dyn_emb_table_options_list.append(dyn_emb_table_options)

        bdebt = BatchedDynamicEmbeddingTablesV2(
            table_names=table_names,
            table_options=dyn_emb_table_options_list,
            feature_table_map=[0, 0, 1, 2],
            pooling_mode=DynamicEmbPoolingMode.NONE,
            optimizer=opt_type,
            use_index_dedup=True,
            **opt_params,
        )
        bdebt.enable_prefetch = True

        # feature number = 4, batch size = 1 -- every feature is empty:
        #   f0 [], f1 [], f2 [], f3 []
        # 4 features * batch 1 -> 5 offsets, all zero; no indices at all.
        indices = torch.tensor([], dtype=key_type, device=device)
        offsets = torch.tensor([0, 0, 0, 0, 0], dtype=key_type, device=device)

        prefetch_stream = torch.cuda.Stream()
        forward_stream = torch.cuda.Stream()

        if caching:
            # Prefetch on a side stream, handing the forward stream over so the
            # table can order its work against the upcoming forward pass.
            with torch.cuda.stream(prefetch_stream):
                bdebt.prefetch(indices, offsets, forward_stream)
            torch.cuda.synchronize()

        with torch.cuda.stream(forward_stream):
            res = bdebt(indices, offsets)
        torch.cuda.synchronize()

        # Backward through an empty batch must also be a no-op that doesn't crash.
        res.mean().backward()

        # Eval path with the same empty batch.
        with torch.no_grad():
            bdebt.eval()
            bdebt(indices, offsets)
        torch.cuda.synchronize()
    finally:
        # Always restore the environment so a failure here cannot leak
        # determinism mode into subsequent tests.
        if deterministic:
            del os.environ["DEMB_DETERMINISM_MODE"]

    print("all check passed")
0 commit comments