@@ -425,6 +425,46 @@ def union_int():
425
425
assert python_result [0 ] == pyccel_result [0 ]
426
426
assert set (python_result [1 :]) == set (pyccel_result [1 :])
427
427
428
+ def test_set_intersection_int (python_only_language ):
429
+ def intersection_int ():
430
+ a = {1 ,2 ,3 }
431
+ b = {2 ,3 ,4 }
432
+ c = a .intersection (b )
433
+ return len (c ), c .pop (), c .pop ()
434
+
435
+ epyccel_func = epyccel (intersection_int , language = python_only_language )
436
+ pyccel_result = epyccel_func ()
437
+ python_result = intersection_int ()
438
+ assert python_result [0 ] == pyccel_result [0 ]
439
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
440
+
441
+ def test_set_intersection_no_args (python_only_language ):
442
+ def intersection_int ():
443
+ a = {1 ,2 ,3 ,4 }
444
+ c = a .intersection ()
445
+ a .add (5 )
446
+ return len (c ), c .pop (), c .pop (), c .pop (), c .pop ()
447
+
448
+ epyccel_func = epyccel (intersection_int , language = python_only_language )
449
+ pyccel_result = epyccel_func ()
450
+ python_result = intersection_int ()
451
+ assert python_result [0 ] == pyccel_result [0 ]
452
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
453
+
454
+ def test_set_intersection_2_args (python_only_language ):
455
+ def intersection_int ():
456
+ a = {1 ,2 ,3 ,4 }
457
+ b = {5 ,6 ,7 ,2 ,1 ,3 }
458
+ c = {7 ,6 ,10 ,4 ,2 ,3 ,1 }
459
+ d = a .intersection (b , c )
460
+ return len (d ), d .pop (), d .pop (), d .pop ()
461
+
462
+ epyccel_func = epyccel (intersection_int , language = python_only_language )
463
+ pyccel_result = epyccel_func ()
464
+ python_result = intersection_int ()
465
+ assert python_result [0 ] == pyccel_result [0 ]
466
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
467
+
428
468
@pytest .mark .parametrize ( 'language' , (
429
469
pytest .param ("fortran" , marks = pytest .mark .fortran ),
430
470
pytest .param ("c" , marks = [
@@ -485,6 +525,57 @@ def union_int():
485
525
assert python_result [0 ] == pyccel_result [0 ]
486
526
assert set (python_result [1 :]) == set (pyccel_result [1 :])
487
527
528
+ def test_temporary_set_intersection (python_only_language ):
529
+ def intersection_int ():
530
+ a = {1 ,2 }
531
+ b = {2 }
532
+ d = a .intersection (b ).pop ()
533
+ return d
534
+
535
+ epyccel_func = epyccel (intersection_int , language = python_only_language )
536
+ pyccel_result = epyccel_func ()
537
+ python_result = intersection_int ()
538
+ assert python_result == pyccel_result
539
+
540
+ def test_set_intersection_list (python_only_language ):
541
+ def intersection_list ():
542
+ a = {1.2 , 2.3 , 5.0 }
543
+ b = [1.2 , 5.0 , 4.0 ]
544
+ d = a .intersection (b )
545
+ return len (d ), d .pop (), d .pop ()
546
+
547
+ epyccel_func = epyccel (intersection_list , language = python_only_language )
548
+ pyccel_result = epyccel_func ()
549
+ python_result = intersection_list ()
550
+ assert python_result [0 ] == pyccel_result [0 ]
551
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
552
+
553
+ def test_set_intersection_tuple (python_only_language ):
554
+ def intersection_tuple ():
555
+ a = {True }
556
+ b = (False , True )
557
+ d = a .intersection (b )
558
+ return len (d ), d .pop ()
559
+
560
+ epyccel_func = epyccel (intersection_tuple , language = python_only_language )
561
+ pyccel_result = epyccel_func ()
562
+ python_result = intersection_tuple ()
563
+ assert python_result [0 ] == pyccel_result [0 ]
564
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
565
+
566
+ def test_set_intersection_operator (python_only_language ):
567
+ def intersection_int ():
568
+ a = {1 ,2 ,3 ,4 ,8 }
569
+ b = {5 ,2 ,3 ,7 ,8 }
570
+ c = a & b
571
+ return len (c ), c .pop (), c .pop (), c .pop ()
572
+
573
+ epyccel_func = epyccel (intersection_int , language = python_only_language )
574
+ pyccel_result = epyccel_func ()
575
+ python_result = intersection_int ()
576
+ assert python_result [0 ] == pyccel_result [0 ]
577
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
578
+
488
579
@pytest .mark .parametrize ( 'language' , (
489
580
pytest .param ("fortran" , marks = [
490
581
pytest .mark .xfail (reason = "Update not fully implemented yet. See #2022" ),
@@ -521,6 +612,19 @@ def union_int():
521
612
python_result = union_int ()
522
613
assert python_result == pyccel_result
523
614
615
+ def test_set_intersection_augoperator (python_only_language ):
616
+ def intersection_int ():
617
+ a = {1 ,2 ,3 ,4 }
618
+ b = {2 ,3 ,4 }
619
+ a &= b
620
+ return len (a ), a .pop (), a .pop (), a .pop ()
621
+
622
+ epyccel_func = epyccel (intersection_int , language = python_only_language )
623
+ pyccel_result = epyccel_func ()
624
+ python_result = intersection_int ()
625
+ assert python_result [0 ] == pyccel_result [0 ]
626
+ assert set (python_result [1 :]) == set (pyccel_result [1 :])
627
+
524
628
def test_set_ptr (language ):
525
629
def set_ptr ():
526
630
a = {1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 }
0 commit comments