@@ -281,25 +281,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
281281 EXPECT_EQ (validResult.getDimension (), 2u );
282282}
283283
284- // Helper to create a minimal function and embedder for getter tests
285- struct GetterTestEnv {
286- Vocab V = {};
284+ // Fixture for IR2Vec tests requiring IR setup and weight management.
285+ class IR2VecTestFixture : public ::testing::Test {
286+ protected:
287+ Vocab V;
287288 LLVMContext Ctx;
288- std::unique_ptr<Module> M = nullptr ;
289+ std::unique_ptr<Module> M;
289290 Function *F = nullptr ;
290291 BasicBlock *BB = nullptr ;
291- Instruction *Add = nullptr ;
292- Instruction *Ret = nullptr ;
293- std::unique_ptr<Embedder> Emb = nullptr ;
292+ Instruction *AddInst = nullptr ;
293+ Instruction *RetInst = nullptr ;
294294
295- GetterTestEnv () {
295+ float OriginalOpcWeight = ::OpcWeight;
296+ float OriginalTypeWeight = ::TypeWeight;
297+ float OriginalArgWeight = ::ArgWeight;
298+
299+ void SetUp () override {
296300 V = {{" add" , {1.0 , 2.0 }},
297301 {" integerTy" , {0.5 , 0.5 }},
298302 {" constant" , {0.2 , 0.3 }},
299303 {" variable" , {0.0 , 0.0 }},
300304 {" unknownTy" , {0.0 , 0.0 }}};
301305
302- M = std::make_unique<Module>(" M" , Ctx);
306+ // Setup IR
307+ M = std::make_unique<Module>(" TestM" , Ctx);
303308 FunctionType *FTy = FunctionType::get (
304309 Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
305310 false );
@@ -308,61 +313,82 @@ struct GetterTestEnv {
308313 Argument *Arg = F->getArg (0 );
309314 llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
310315
311- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
312- Ret = ReturnInst::Create (Ctx, Add, BB);
316+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
317+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
318+ }
319+
320+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
321+ ::OpcWeight = OpcWeight;
322+ ::TypeWeight = TypeWeight;
323+ ::ArgWeight = ArgWeight;
324+ }
313325
314- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
315- EXPECT_TRUE (static_cast <bool >(Result));
316- Emb = std::move (*Result);
326+ void TearDown () override {
327+ // Restore original global weights
328+ ::OpcWeight = OriginalOpcWeight;
329+ ::TypeWeight = OriginalTypeWeight;
330+ ::ArgWeight = OriginalArgWeight;
317331 }
318332};
319333
320- TEST (IR2VecTest, GetInstVecMap) {
321- GetterTestEnv Env;
322- const auto &InstMap = Env.Emb ->getInstVecMap ();
334+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
335+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
336+ ASSERT_TRUE (static_cast <bool >(Result));
337+ auto Emb = std::move (*Result);
338+
339+ const auto &InstMap = Emb->getInstVecMap ();
323340
324341 EXPECT_EQ (InstMap.size (), 2u );
325- EXPECT_TRUE (InstMap.count (Env. Add ));
326- EXPECT_TRUE (InstMap.count (Env. Ret ));
342+ EXPECT_TRUE (InstMap.count (AddInst ));
343+ EXPECT_TRUE (InstMap.count (RetInst ));
327344
328- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
329- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
345+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
346+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
330347
331348 // Check values for add: {1.29, 2.31}
332- EXPECT_THAT (InstMap.at (Env. Add ),
349+ EXPECT_THAT (InstMap.at (AddInst ),
333350 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
334351
335352 // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
336353 // vocab
337- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
354+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
338355}
339356
340- TEST (IR2VecTest, GetBBVecMap) {
341- GetterTestEnv Env;
342- const auto &BBMap = Env.Emb ->getBBVecMap ();
357+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
358+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
359+ ASSERT_TRUE (static_cast <bool >(Result));
360+ auto Emb = std::move (*Result);
361+
362+ const auto &BBMap = Emb->getBBVecMap ();
343363
344364 EXPECT_EQ (BBMap.size (), 1u );
345- EXPECT_TRUE (BBMap.count (Env. BB ));
346- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
365+ EXPECT_TRUE (BBMap.count (BB));
366+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
347367
348368 // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
349369 // {1.29, 2.31}
350- EXPECT_THAT (BBMap.at (Env. BB ),
370+ EXPECT_THAT (BBMap.at (BB),
351371 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
352372}
353373
354- TEST (IR2VecTest, GetBBVector) {
355- GetterTestEnv Env;
356- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
374+ TEST_F (IR2VecTestFixture, GetBBVector) {
375+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
376+ ASSERT_TRUE (static_cast <bool >(Result));
377+ auto Emb = std::move (*Result);
378+
379+ const auto &BBVec = Emb->getBBVector (*BB);
357380
358381 EXPECT_EQ (BBVec.size (), 2u );
359382 EXPECT_THAT (BBVec,
360383 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
361384}
362385
363- TEST (IR2VecTest, GetFunctionVector) {
364- GetterTestEnv Env;
365- const auto &FuncVec = Env.Emb ->getFunctionVector ();
386+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
387+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
388+ ASSERT_TRUE (static_cast <bool >(Result));
389+ auto Emb = std::move (*Result);
390+
391+ const auto &FuncVec = Emb->getFunctionVector ();
366392
367393 EXPECT_EQ (FuncVec.size (), 2u );
368394
@@ -371,4 +397,45 @@ TEST(IR2VecTest, GetFunctionVector) {
371397 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
372398}
373399
400+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
401+ setWeights (1.0 , 1.0 , 1.0 );
402+
403+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
404+ ASSERT_TRUE (static_cast <bool >(Result));
405+ auto Emb = std::move (*Result);
406+
407+ const auto &FuncVec = Emb->getFunctionVector ();
408+
409+ EXPECT_EQ (FuncVec.size (), 2u );
410+
411+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
412+ // 0.3] + [0.0 0.0])
413+ EXPECT_THAT (FuncVec,
414+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
415+ }
416+
417+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
418+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
419+ Vocab ExpectedVocab = InitialVocab;
420+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
421+
422+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
423+
424+ LLVMContext TestCtx;
425+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
426+ ModuleAnalysisManager MAM;
427+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
428+
429+ EXPECT_TRUE (Result.isValid ());
430+ ASSERT_FALSE (Result.getVocabulary ().empty ());
431+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
432+
433+ const auto &ResultVocab = Result.getVocabulary ();
434+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
435+ for (const auto &pair : ExpectedVocab) {
436+ EXPECT_TRUE (ResultVocab.count (pair.first ));
437+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
438+ }
439+ }
440+
374441} // end anonymous namespace
0 commit comments