|
7 | 7 |
|
8 | 8 | import finchlite.finch_assembly as asm |
9 | 9 | from finchlite.codegen import NumpyBuffer |
| 10 | +from finchlite.codegen.hashtable import CHashTable, NumbaHashTable |
10 | 11 | from finchlite.finch_assembly import assembly_check_types |
| 12 | +from finchlite.finch_assembly.struct import TupleFType |
11 | 13 | from finchlite.symbolic import FType, ftype |
12 | 14 |
|
13 | 15 |
|
@@ -684,3 +686,99 @@ def test_simple_struct(): |
684 | 686 | ) |
685 | 687 |
|
686 | 688 | assembly_check_types(mod) |
| 689 | + |
| 690 | + |
| 691 | +@pytest.mark.parametrize( |
| 692 | + ["tabletype"], |
| 693 | + [ |
| 694 | + (CHashTable,), |
| 695 | + (NumbaHashTable,), |
| 696 | + ], |
| 697 | +) |
| 698 | +def test_hashtable(tabletype): |
| 699 | + table = tabletype(2, 3) |
| 700 | + |
| 701 | + table_v = asm.Variable("a", ftype(table)) |
| 702 | + table_slt = asm.Slot("a_", ftype(table)) |
| 703 | + |
| 704 | + key_type = table.ftype.key_type |
| 705 | + val_type = table.ftype.value_type |
| 706 | + key_v = asm.Variable("key", key_type) |
| 707 | + val_v = asm.Variable("val", val_type) |
| 708 | + |
| 709 | + mod = asm.Module( |
| 710 | + ( |
| 711 | + asm.Function( |
| 712 | + asm.Variable( |
| 713 | + "setidx", TupleFType.from_tuple(tuple(int for _ in range(3))) |
| 714 | + ), |
| 715 | + (table_v, key_v, val_v), |
| 716 | + asm.Block( |
| 717 | + ( |
| 718 | + asm.Unpack(table_slt, table_v), |
| 719 | + asm.StoreMap( |
| 720 | + table_slt, |
| 721 | + key_v, |
| 722 | + val_v, |
| 723 | + ), |
| 724 | + asm.Repack(table_slt), |
| 725 | + asm.Return(asm.LoadMap(table_slt, key_v)), |
| 726 | + ) |
| 727 | + ), |
| 728 | + ), |
| 729 | + asm.Function( |
| 730 | + asm.Variable("exists", bool), |
| 731 | + (table_v, key_v), |
| 732 | + asm.Block( |
| 733 | + ( |
| 734 | + asm.Unpack(table_slt, table_v), |
| 735 | + asm.Return(asm.ExistsMap(table_slt, key_v)), |
| 736 | + ) |
| 737 | + ), |
| 738 | + ), |
| 739 | + ) |
| 740 | + ) |
| 741 | + assembly_check_types(mod) |
| 742 | + |
| 743 | + |
| 744 | +@pytest.mark.parametrize( |
| 745 | + ["tabletype"], |
| 746 | + [ |
| 747 | + (CHashTable,), |
| 748 | + (NumbaHashTable,), |
| 749 | + ], |
| 750 | +) |
| 751 | +def test_hashtable_fail(tabletype): |
| 752 | + table = tabletype(2, 3) |
| 753 | + |
| 754 | + table_v = asm.Variable("a", ftype(table)) |
| 755 | + table_slt = asm.Slot("a_", ftype(table)) |
| 756 | + |
| 757 | + key_type = table.ftype.key_type |
| 758 | + val_type = table.ftype.value_type |
| 759 | + key_v = asm.Variable("key", key_type) |
| 760 | + val_v = asm.Variable("val", val_type) |
| 761 | + mod = asm.Module( |
| 762 | + ( |
| 763 | + asm.Function( |
| 764 | + asm.Variable( |
| 765 | + "setidx", TupleFType.from_tuple(tuple(int for _ in range(2))) |
| 766 | + ), |
| 767 | + (table_v, key_v, val_v), |
| 768 | + asm.Block( |
| 769 | + ( |
| 770 | + asm.Unpack(table_slt, table_v), |
| 771 | + asm.StoreMap( |
| 772 | + table_slt, |
| 773 | + key_v, |
| 774 | + val_v, |
| 775 | + ), |
| 776 | + asm.Repack(table_slt), |
| 777 | + asm.Return(asm.LoadMap(table_slt, key_v)), |
| 778 | + ) |
| 779 | + ), |
| 780 | + ), |
| 781 | + ) |
| 782 | + ) |
| 783 | + with pytest.raises(asm.AssemblyTypeError): |
| 784 | + assembly_check_types(mod) |
0 commit comments