@@ -172,12 +172,14 @@ char CmpLogRoutines::ID = 0;
172
172
#endif
173
173
174
174
bool CmpLogRoutines::hookRtns (Module &M) {
175
- std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC;
176
- LLVMContext &C = M.getContext ();
175
+ std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC,
176
+ Memcmp, Strcmp, Strncmp;
177
+ LLVMContext &C = M.getContext ();
177
178
178
179
Type *VoidTy = Type::getVoidTy (C);
179
180
// PointerType *VoidPtrTy = PointerType::get(VoidTy, 0);
180
181
IntegerType *Int8Ty = IntegerType::getInt8Ty (C);
182
+ IntegerType *Int64Ty = IntegerType::getInt64Ty (C);
181
183
PointerType *i8PtrTy = PointerType::get (Int8Ty, 0 );
182
184
183
185
#if LLVM_VERSION_MAJOR < 9
@@ -269,6 +271,60 @@ bool CmpLogRoutines::hookRtns(Module &M) {
269
271
FunctionCallee cmplogGccStdC = c4;
270
272
#endif
271
273
274
+ #if LLVM_VERSION_MAJOR >= 9
275
+ FunctionCallee
276
+ #else
277
+ Constant *
278
+ #endif
279
+ c5 = M.getOrInsertFunction (" __cmplog_rtn_hook_n" , VoidTy, i8PtrTy,
280
+ i8PtrTy, Int64Ty
281
+ #if LLVM_VERSION_MAJOR < 5
282
+ ,
283
+ NULL
284
+ #endif
285
+ );
286
+ #if LLVM_VERSION_MAJOR >= 9
287
+ FunctionCallee cmplogHookFnN = c5;
288
+ #else
289
+ Function *cmplogHookFnN = cast<Function>(c5);
290
+ #endif
291
+
292
+ #if LLVM_VERSION_MAJOR >= 9
293
+ FunctionCallee
294
+ #else
295
+ Constant *
296
+ #endif
297
+ c6 = M.getOrInsertFunction (" __cmplog_rtn_hook_strn" , VoidTy, i8PtrTy,
298
+ i8PtrTy, Int64Ty
299
+ #if LLVM_VERSION_MAJOR < 5
300
+ ,
301
+ NULL
302
+ #endif
303
+ );
304
+ #if LLVM_VERSION_MAJOR >= 9
305
+ FunctionCallee cmplogHookFnStrN = c6;
306
+ #else
307
+ Function *cmplogHookFnStrN = cast<Function>(c6);
308
+ #endif
309
+
310
+ #if LLVM_VERSION_MAJOR >= 9
311
+ FunctionCallee
312
+ #else
313
+ Constant *
314
+ #endif
315
+ c7 = M.getOrInsertFunction (" __cmplog_rtn_hook_str" , VoidTy, i8PtrTy,
316
+ i8PtrTy
317
+ #if LLVM_VERSION_MAJOR < 5
318
+ ,
319
+ NULL
320
+ #endif
321
+ );
322
+ #if LLVM_VERSION_MAJOR >= 9
323
+ FunctionCallee cmplogHookFnStr = c7;
324
+ #else
325
+ Function *cmplogHookFnStr = cast<Function>(c7);
326
+ #endif
327
+
272
328
/* iterate over all functions, bbs and instruction and add suitable calls */
273
329
for (auto &F : M) {
274
330
if (isIgnoreFunction (&F)) { continue ; }
@@ -283,12 +339,87 @@ bool CmpLogRoutines::hookRtns(Module &M) {
283
339
if (callInst->getCallingConv () != llvm::CallingConv::C) { continue ; }
284
340
285
341
FunctionType *FT = Callee->getFunctionType ();
342
+ std::string FuncName = Callee->getName ().str ();
286
343
287
344
bool isPtrRtn = FT->getNumParams () >= 2 &&
288
345
!FT->getReturnType ()->isVoidTy () &&
289
346
FT->getParamType (0 ) == FT->getParamType (1 ) &&
290
347
FT->getParamType (0 )->isPointerTy ();
291
348
349
+ bool isPtrRtnN = FT->getNumParams () >= 3 &&
350
+ !FT->getReturnType ()->isVoidTy () &&
351
+ FT->getParamType (0 ) == FT->getParamType (1 ) &&
352
+ FT->getParamType (0 )->isPointerTy () &&
353
+ FT->getParamType (2 )->isIntegerTy ();
354
+ if (isPtrRtnN) {
355
+ auto intTyOp =
356
+ dyn_cast<IntegerType>(callInst->getArgOperand (2 )->getType ());
357
+ if (intTyOp) {
358
+ if (intTyOp->getBitWidth () != 32 &&
359
+ intTyOp->getBitWidth () != 64 ) {
360
+ isPtrRtnN = false ;
361
+ }
362
+ }
363
+ }
364
+
365
+ bool isMemcmp =
366
+ (!FuncName.compare (" memcmp" ) || !FuncName.compare (" bcmp" ) ||
367
+ !FuncName.compare (" CRYPTO_memcmp" ) ||
368
+ !FuncName.compare (" OPENSSL_memcmp" ) ||
369
+ !FuncName.compare (" memcmp_const_time" ) ||
370
+ !FuncName.compare (" memcmpct" ));
371
+ isMemcmp &= FT->getNumParams () == 3 &&
372
+ FT->getReturnType ()->isIntegerTy (32 ) &&
373
+ FT->getParamType (0 )->isPointerTy () &&
374
+ FT->getParamType (1 )->isPointerTy () &&
375
+ FT->getParamType (2 )->isIntegerTy ();
376
+
377
+ bool isStrcmp =
378
+ (!FuncName.compare (" strcmp" ) || !FuncName.compare (" xmlStrcmp" ) ||
379
+ !FuncName.compare (" xmlStrEqual" ) ||
380
+ !FuncName.compare (" g_strcmp0" ) ||
381
+ !FuncName.compare (" curl_strequal" ) ||
382
+ !FuncName.compare (" strcsequal" ) ||
383
+ !FuncName.compare (" strcasecmp" ) ||
384
+ !FuncName.compare (" stricmp" ) ||
385
+ !FuncName.compare (" ap_cstr_casecmp" ) ||
386
+ !FuncName.compare (" OPENSSL_strcasecmp" ) ||
387
+ !FuncName.compare (" xmlStrcasecmp" ) ||
388
+ !FuncName.compare (" g_strcasecmp" ) ||
389
+ !FuncName.compare (" g_ascii_strcasecmp" ) ||
390
+ !FuncName.compare (" Curl_strcasecompare" ) ||
391
+ !FuncName.compare (" Curl_safe_strcasecompare" ) ||
392
+ !FuncName.compare (" cmsstrcasecmp" ) ||
393
+ !FuncName.compare (" strstr" ) ||
394
+ !FuncName.compare (" g_strstr_len" ) ||
395
+ !FuncName.compare (" ap_strcasestr" ) ||
396
+ !FuncName.compare (" xmlStrstr" ) ||
397
+ !FuncName.compare (" xmlStrcasestr" ) ||
398
+ !FuncName.compare (" g_str_has_prefix" ) ||
399
+ !FuncName.compare (" g_str_has_suffix" ));
400
+ isStrcmp &=
401
+ FT->getNumParams () == 2 && FT->getReturnType ()->isIntegerTy (32 ) &&
402
+ FT->getParamType (0 ) == FT->getParamType (1 ) &&
403
+ FT->getParamType (0 ) == IntegerType::getInt8PtrTy (M.getContext ());
404
+
405
+ bool isStrncmp = (!FuncName.compare (" strncmp" ) ||
406
+ !FuncName.compare (" xmlStrncmp" ) ||
407
+ !FuncName.compare (" curl_strnequal" ) ||
408
+ !FuncName.compare (" strncasecmp" ) ||
409
+ !FuncName.compare (" strnicmp" ) ||
410
+ !FuncName.compare (" ap_cstr_casecmpn" ) ||
411
+ !FuncName.compare (" OPENSSL_strncasecmp" ) ||
412
+ !FuncName.compare (" xmlStrncasecmp" ) ||
413
+ !FuncName.compare (" g_ascii_strncasecmp" ) ||
414
+ !FuncName.compare (" Curl_strncasecompare" ) ||
415
+ !FuncName.compare (" g_strncasecmp" ));
416
+ isStrncmp &= FT->getNumParams () == 3 &&
417
+ FT->getReturnType ()->isIntegerTy (32 ) &&
418
+ FT->getParamType (0 ) == FT->getParamType (1 ) &&
419
+ FT->getParamType (0 ) ==
420
+ IntegerType::getInt8PtrTy (M.getContext ()) &&
421
+ FT->getParamType (2 )->isIntegerTy ();
422
+
292
423
bool isGccStdStringStdString =
293
424
Callee->getName ().find (" __is_charIT_EE7__value" ) !=
294
425
std::string::npos &&
@@ -336,10 +467,13 @@ bool CmpLogRoutines::hookRtns(Module &M) {
336
467
*/
337
468
338
469
if (isGccStdStringCString || isGccStdStringStdString ||
339
- isLlvmStdStringStdString || isLlvmStdStringCString) {
340
- isPtrRtn = false ;
470
+ isLlvmStdStringStdString || isLlvmStdStringCString || isMemcmp ||
471
+ isStrcmp || isStrncmp) {
472
+ isPtrRtnN = isPtrRtn = false ;
341
473
}
342
474
475
+ if (isPtrRtnN) { isPtrRtn = false ; }
476
+
343
477
if (isPtrRtn) { calls.push_back (callInst); }
344
478
if (isGccStdStringStdString) { gccStdStd.push_back (callInst); }
345
479
if (isGccStdStringCString) { gccStdC.push_back (callInst); }
@@ -351,9 +485,9 @@ bool CmpLogRoutines::hookRtns(Module &M) {
351
485
}
352
486
353
487
if (!calls.size () && !gccStdStd.size () && !gccStdC.size () &&
354
- !llvmStdStd.size () && !llvmStdC.size ()) {
488
+ !llvmStdStd.size () && !llvmStdC.size () && !Memcmp.size () &&
489
+ Strcmp.size () && Strncmp.size ())
355
490
return false ;
356
- }
357
491
358
492
for (auto &callInst : calls) {
359
493
Value *v1P = callInst->getArgOperand (0 ), *v2P = callInst->getArgOperand (1 );
@@ -372,6 +506,64 @@ bool CmpLogRoutines::hookRtns(Module &M) {
372
506
// errs() << callInst->getCalledFunction()->getName() << "\n";
373
507
}
374
508
509
+ for (auto &callInst : Memcmp) {
510
+ Value *v1P = callInst->getArgOperand (0 ), *v2P = callInst->getArgOperand (1 ),
511
+ *v3P = callInst->getArgOperand (2 );
512
+
513
+ IRBuilder<> IRB (callInst->getParent ());
514
+
515
+ std::vector<Value *> args;
516
+ Value *v1Pcasted = IRB.CreatePointerCast (v1P, i8PtrTy);
517
+ Value *v2Pcasted = IRB.CreatePointerCast (v2P, i8PtrTy);
518
+ Value *v3Pbitcast = IRB.CreateBitCast (
519
+ v3P, IntegerType::get (C, v3P->getType ()->getPrimitiveSizeInBits ()));
520
+ Value *v3Pcasted =
521
+ IRB.CreateIntCast (v3Pbitcast, IntegerType::get (C, 64 ), false );
522
+ args.push_back (v1Pcasted);
523
+ args.push_back (v2Pcasted);
524
+ args.push_back (v3Pcasted);
525
+
526
+ IRB.CreateCall (cmplogHookFnN, args);
527
+
528
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
529
+ }
530
+
531
+ for (auto &callInst : Strcmp) {
532
+ Value *v1P = callInst->getArgOperand (0 ), *v2P = callInst->getArgOperand (1 );
533
+
534
+ IRBuilder<> IRB (callInst->getParent ());
535
+ std::vector<Value *> args;
536
+ Value *v1Pcasted = IRB.CreatePointerCast (v1P, i8PtrTy);
537
+ Value *v2Pcasted = IRB.CreatePointerCast (v2P, i8PtrTy);
538
+ args.push_back (v1Pcasted);
539
+ args.push_back (v2Pcasted);
540
+
541
+ IRB.CreateCall (cmplogHookFnStr, args);
542
+
543
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
544
+ }
545
+
546
+ for (auto &callInst : Strncmp) {
547
+ Value *v1P = callInst->getArgOperand (0 ), *v2P = callInst->getArgOperand (1 ),
548
+ *v3P = callInst->getArgOperand (2 );
549
+
550
+ IRBuilder<> IRB (callInst->getParent ());
551
+ std::vector<Value *> args;
552
+ Value *v1Pcasted = IRB.CreatePointerCast (v1P, i8PtrTy);
553
+ Value *v2Pcasted = IRB.CreatePointerCast (v2P, i8PtrTy);
554
+ Value *v3Pbitcast = IRB.CreateBitCast (
555
+ v3P, IntegerType::get (C, v3P->getType ()->getPrimitiveSizeInBits ()));
556
+ Value *v3Pcasted =
557
+ IRB.CreateIntCast (v3Pbitcast, IntegerType::get (C, 64 ), false );
558
+ args.push_back (v1Pcasted);
559
+ args.push_back (v2Pcasted);
560
+ args.push_back (v3Pcasted);
561
+
562
+ IRB.CreateCall (cmplogHookFnStrN, args);
563
+
564
+ // errs() << callInst->getCalledFunction()->getName() << "\n";
565
+ }
566
+
375
567
for (auto &callInst : gccStdStd) {
376
568
Value *v1P = callInst->getArgOperand (0 ), *v2P = callInst->getArgOperand (1 );
377
569
0 commit comments