1+ println (" *" ^ 20 ," \n Hello from Julia!\n " ," *" ^ 20 )
2+
3+ const MSet = Dict{Int,Int}
4+
5+ function build_set_and_tids_mset (token_ids)
6+ tids_mset = MSet ()
7+
8+ for tid in token_ids
9+ # this skips already matched token ids that are -1
10+ tid == - 1 && continue
11+ tids_mset[tid] = get (()-> 0 , tids_mset, tid) + 1
12+ end
13+
14+ return BitSet (keys (tids_mset)), tids_mset
15+ end
16+
17+ set_counter (set:: BitSet ) = length (set)
18+ set_counter (set:: MSet ) = sum (values (set))
19+
20+ set_intersector (set1:: BitSet , set2:: BitSet ) = intersect (set1, set2)
21+ function set_intersector (set1:: MSet , set2:: MSet )
22+ length (set1) > length (set2) && ((set1,set2) = (set2,set1))
23+ ret = MSet ()
24+ for (k, v) in set1
25+ ! haskey (set2, k) && continue
26+ ret[k] = min (v, set2[k])
27+ end
28+ return ret
29+ end
30+
31+ set_high_intersection_filter (set:: BitSet , cutoff) = filter (<= (cutoff), set)
32+ set_high_intersection_filter (set:: MSet , cutoff) = filter (pair -> pair. first<= cutoff, set)
33+
34+ function compare_token_sets (qset, iset, len_legalese, min_matched_length_high, min_matched_length; minimum_containment= 0 , high_resemblance_threshold= 0.8 )
35+ intersection = set_intersector (qset, iset)
36+ length (intersection) == 0 && return nothing ,nothing
37+ high_intersection = set_high_intersection_filter (intersection, len_legalese)
38+ length (high_intersection) == 0 && return nothing ,nothing
39+ length (set_counter (high_intersection)) < min_matched_length_high && return nothing ,nothing
40+
41+ rule_length = set_counter (iset)
42+ matched_length = set_counter (intersection)
43+ matched_length < min_matched_length && return nothing , nothing
44+
45+ union_len = set_counter (qset) + rule_length - matched_length
46+ resemblance = matched_length / union_len
47+ containment = matched_length / rule_length
48+ containment < minimum_containment && return nothing , nothing
49+
50+ amplified_resemblance = resemblance^ 2
51+ score_vec1 = (;
52+ is_highly_resemblant= round (resemblance; digits= 1 ) >= high_resemblance_threshold,
53+ containment= round (containment; digits= 1 ),
54+ resemblance= round (amplified_resemblance; digits= 1 ),
55+ matched_length= round (Int, matched_length / 20 ))
56+
57+ score_vec2 = (;
58+ is_highly_resemblant= resemblance >= high_resemblance_threshold,
59+ containment= containment,
60+ resemblance= amplified_resemblance,
61+ matched_length= matched_length
62+ )
63+
64+ return (score_vec1,score_vec2), high_intersection
65+
66+ end
67+
68+ const ScoreVector = @NamedTuple {is_highly_resemblant:: Bool , containment:: Float64 , resemblance:: Float64 , matched_length:: Int64 }
69+
70+ struct RuleInfo
71+ min_matched_length_unique:: Int
72+ min_matched_length:: Int
73+ min_high_matched_length_unique:: Int
74+ min_high_matched_length:: Int
75+ minimum_containment:: Float64
76+ end
77+
78+ function convert_rule_list (rules_by_rid)
79+ return [RuleInfo (
80+ pyconvert (Any, r. get_min_matched_length (true )),
81+ pyconvert (Any, r. get_min_matched_length (false )),
82+ pyconvert (Any, r. get_min_high_matched_length (true )),
83+ pyconvert (Any, r. get_min_high_matched_length (false )),
84+ pyconvert (Any, r. _minimum_containment)) for r in rules_by_rid]
85+ end
86+
87+ function convert_set_list (sets)
88+ return [isnothing (set) ? nothing : BitSet (set) for set in sets]
89+ end
90+
91+ function convert_mset_list (msets)
92+ return [isnothing (mset) ? nothing : MSet (mset) for mset in msets]
93+ end
94+
95+ function compute_candidates (token_ids, len_legalese, rules_by_rid, sets_by_rid, msets_by_rid,
96+ matchable_rids, top= 50 , high_resemblance= false , high_resemblance_threshold= 0.8 )
97+ # collect query-side sets used for matching
98+ qset, qmset = build_set_and_tids_mset (token_ids)
99+
100+ # @info "compute_candidates" typeof(token_ids) typeof(len_legalese) typeof(rules_by_rid) typeof(sets_by_rid) typeof(msets_by_rid) typeof(matchable_rids) typeof(top) typeof(high_resemblance) typeof(high_resemblance_threshold)
101+ # typeof(token_ids) = Vector{Int64} (alias for Array{Int64, 1})
102+ # typeof(len_legalese) = Int64
103+ # typeof(rules_by_rid) = Vector{RuleInfo} (alias for Array{RuleInfo, 1})
104+ # typeof(sets_by_rid) = Vector{Union{Nothing, BitSet}} (alias for Array{Union{Nothing, BitSet}, 1})
105+ # typeof(msets_by_rid) = Vector{Union{Nothing, Dict{Int64, Int64}}} (alias for Array{Union{Nothing, Dict{Int64, Int64}}, 1})
106+ # typeof(matchable_rids) = BitSet
107+ # typeof(top) = Int64
108+ # typeof(high_resemblance) = Bool
109+ # typeof(high_resemblance_threshold) = Float64
110+
111+
112+ # perform two steps of ranking:
113+ # step one with tid sets and step two with tid multisets for refinement
114+
115+ # ###########################################################################
116+ # step 1 is on token id sets:
117+ # ###########################################################################
118+
119+ sortable_candidates = Tuple{Tuple{ScoreVector,ScoreVector}, Int, RuleInfo, BitSet}[]
120+
121+ for (rid, rule) in enumerate (rules_by_rid)
122+ rid -= 1 # julia python compat
123+ rid in matchable_rids || continue
124+
125+ scores_vectors, high_set_intersection = compare_token_sets (
126+ qset,
127+ sets_by_rid[rid+ 1 ],
128+ len_legalese,
129+ rule. min_high_matched_length_unique,
130+ rule. min_matched_length_unique;
131+ minimum_containment= rule. minimum_containment,
132+ high_resemblance_threshold)
133+
134+ if ! isnothing (scores_vectors)
135+ svr, svf = scores_vectors
136+ if (! high_resemblance || (high_resemblance && svr. is_highly_resemblant && svf. is_highly_resemblant))
137+ # @info "" scores_vectors rid rule high_set_intersection
138+ push! (sortable_candidates, (scores_vectors, rid, rule, high_set_intersection))
139+ end
140+ end
141+ end
142+
143+ length (sortable_candidates) == 0 && return sortable_candidates
144+
145+ sort! (sortable_candidates; rev= true )
146+
147+ # ###################################################################
148+ # step 2 is on tids multisets
149+ # ###################################################################
150+ # keep only the 10 x top candidates
151+ sortable_candidates_new = eltype (sortable_candidates)[]
152+ for (k , (_score_vectors, rid, rule, high_set_intersection)) in enumerate (sortable_candidates)
153+ k >= 10 * top && break
154+ scores_vectors, _intersection = compare_token_sets (
155+ qmset,
156+ msets_by_rid[rid+ 1 ],
157+ len_legalese,
158+ rule. min_high_matched_length,
159+ rule. min_matched_length;
160+ minimum_containment= rule. minimum_containment,
161+ high_resemblance_threshold)
162+
163+ if ! isnothing (scores_vectors)
164+ svr, svf = scores_vectors
165+ if (! high_resemblance || (high_resemblance && svr. is_highly_resemblant && svf. is_highly_resemblant))
166+ push! (sortable_candidates_new, (scores_vectors, rid, rule, high_set_intersection))
167+ end
168+ end
169+ end
170+
171+ length (sortable_candidates_new) == 0 && return sortable_candidates_new
172+
173+ # rank candidates
174+ return sort! (sortable_candidates_new; rev= true )[1 : min (top, length (sortable_candidates_new))]
175+ end
0 commit comments