@@ -388,3 +388,195 @@ fn mock_call_when_complex_types() {
388
388
let result = run_test_case ( & test) ;
389
389
assert_passed ( & result) ;
390
390
}
391
+
392
+ #[ test]
393
+ fn mock_calls_when ( ) {
394
+ let test = test_case ! (
395
+ indoc!(
396
+ r#"
397
+ use result::ResultTrait;
398
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData};
399
+
400
+ #[starknet::interface]
401
+ trait IMockChecker<TContractState> {
402
+ fn get_thing(ref self: TContractState) -> felt252;
403
+ }
404
+
405
+ #[test]
406
+ fn mock_call_when_one_specific() {
407
+ let calldata = array![420];
408
+
409
+ let contract = declare("MockChecker").unwrap().contract_class();
410
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
411
+
412
+ let dispatcher = IMockCheckerDispatcher { contract_address };
413
+
414
+ let mock_ret_data = 421;
415
+ let expected_calldata = MockCallData::Values([].span());
416
+ mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1);
417
+
418
+ let thing = dispatcher.get_thing();
419
+ assert_eq!(thing, 421);
420
+
421
+ let thing = dispatcher.get_thing();
422
+ assert_eq!(thing, 420);
423
+ }
424
+
425
+ #[test]
426
+ fn mock_call_when_twice_specific() {
427
+ let calldata = array![420];
428
+
429
+ let contract = declare("MockChecker").unwrap().contract_class();
430
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
431
+
432
+ let dispatcher = IMockCheckerDispatcher { contract_address };
433
+
434
+ let mock_ret_data = 421;
435
+ let expected_calldata = MockCallData::Values([].span());
436
+ mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 2);
437
+
438
+ let thing = dispatcher.get_thing();
439
+ assert_eq!(thing, 421);
440
+
441
+ let thing = dispatcher.get_thing();
442
+ assert_eq!(thing, 421);
443
+
444
+ let thing = dispatcher.get_thing();
445
+ assert_eq!(thing, 420);
446
+ }
447
+
448
+ #[test]
449
+ fn mock_call_when_one_any() {
450
+ let calldata = array![420];
451
+
452
+ let contract = declare("MockChecker").unwrap().contract_class();
453
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
454
+
455
+ let dispatcher = IMockCheckerDispatcher { contract_address };
456
+
457
+ let mock_ret_data = 421;
458
+ mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1);
459
+
460
+ let thing = dispatcher.get_thing();
461
+ assert_eq!(thing, 421);
462
+
463
+ let thing = dispatcher.get_thing();
464
+ assert_eq!(thing, 420);
465
+ }
466
+
467
+ #[test]
468
+ fn mock_call_when_twice_any() {
469
+ let calldata = array![420];
470
+
471
+ let contract = declare("MockChecker").unwrap().contract_class();
472
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
473
+
474
+ let dispatcher = IMockCheckerDispatcher { contract_address };
475
+
476
+ let mock_ret_data = 421;
477
+ mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2);
478
+
479
+ let thing = dispatcher.get_thing();
480
+ assert_eq!(thing, 421);
481
+
482
+ let thing = dispatcher.get_thing();
483
+ assert_eq!(thing, 421);
484
+
485
+ let thing = dispatcher.get_thing();
486
+ assert_eq!(thing, 420);
487
+ }
488
+
489
+ "#
490
+ ) ,
491
+ Contract :: from_code_path(
492
+ "MockChecker" . to_string( ) ,
493
+ Path :: new( "tests/data/contracts/mock_checker.cairo" ) ,
494
+ )
495
+ . unwrap( )
496
+ ) ;
497
+
498
+ let result = run_test_case ( & test) ;
499
+ assert_passed ( & result) ;
500
+ }
501
+
502
+ #[ test]
503
+ fn mock_calls_when_mixed ( ) {
504
+ let test = test_case ! (
505
+ indoc!(
506
+ r#"
507
+ use result::ResultTrait;
508
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData};
509
+
510
+ #[starknet::interface]
511
+ trait IMockChecker<TContractState> {
512
+ fn get_thing(ref self: TContractState) -> felt252;
513
+ }
514
+
515
+ #[test]
516
+ fn mock_call_when_one() {
517
+ let calldata = array![420];
518
+
519
+ let contract = declare("MockChecker").unwrap().contract_class();
520
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
521
+
522
+ let dispatcher = IMockCheckerDispatcher { contract_address };
523
+
524
+ let mock_ret_data = 421;
525
+ let expected_calldata = MockCallData::Values([].span());
526
+ mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1);
527
+ mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 1);
528
+
529
+ let thing = dispatcher.get_thing();
530
+ assert_eq!(thing, 421, "Specific calldata");
531
+
532
+ let thing = dispatcher.get_thing();
533
+ assert_eq!(thing, 422, "Any calldata");
534
+
535
+ let thing = dispatcher.get_thing();
536
+ assert_eq!(thing, 420);
537
+ }
538
+
539
+ #[test]
540
+ fn mock_call_when_multi() {
541
+ let calldata = array![420];
542
+
543
+ let contract = declare("MockChecker").unwrap().contract_class();
544
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
545
+
546
+ let dispatcher = IMockCheckerDispatcher { contract_address };
547
+
548
+ let mock_ret_data = 421;
549
+ let expected_calldata = MockCallData::Values([].span());
550
+ mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 3);
551
+ mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 2);
552
+
553
+ let thing = dispatcher.get_thing();
554
+ assert_eq!(thing, 421, "1st Specific calldata");
555
+
556
+ let thing = dispatcher.get_thing();
557
+ assert_eq!(thing, 421, "2nd Specific calldata");
558
+
559
+ let thing = dispatcher.get_thing();
560
+ assert_eq!(thing, 421, "3rd Specific calldata");
561
+
562
+ let thing = dispatcher.get_thing();
563
+ assert_eq!(thing, 422, "1st Any calldata");
564
+
565
+ let thing = dispatcher.get_thing();
566
+ assert_eq!(thing, 422, "2nd Any calldata");
567
+
568
+ let thing = dispatcher.get_thing();
569
+ assert_eq!(thing, 420);
570
+ }
571
+ "#
572
+ ) ,
573
+ Contract :: from_code_path(
574
+ "MockChecker" . to_string( ) ,
575
+ Path :: new( "tests/data/contracts/mock_checker.cairo" ) ,
576
+ )
577
+ . unwrap( )
578
+ ) ;
579
+
580
+ let result = run_test_case ( & test) ;
581
+ assert_passed ( & result) ;
582
+ }
0 commit comments