@@ -5252,6 +5252,108 @@ def setup_kvec_input(k_vec, k_vec_dict, symmetry=None):
52525252 dict: New entries and those need to be corrected in the data
52535253 to be used in the post request.
52545254 """
5255+ from itertools import product
5256+
5257+ def extract_coeff_and_var (s ):
5258+ """Extract coefficient and var from strings like '2/3a', '1/3b', '1/2', 'a', etc.
5259+ """
5260+ s = s .strip ()
5261+
5262+ if not s :
5263+ return Fraction (1 ), ''
5264+
5265+ # Split into coefficient and character parts
5266+ # Find the last alphabetic character
5267+ char_match = re .search (r'[a-zA-Z]' , s )
5268+
5269+ if char_match :
5270+ # There's a character
5271+ char_pos = char_match .start ()
5272+ coeff_part = s [:char_pos ]
5273+ char_part = s [char_pos :]
5274+
5275+ # Validate that character part is just one letter
5276+ if not re .match (r'^[a-zA-Z]$' , char_part ):
5277+ print (f"Error: Invalid character part: { char_part } " )
5278+ else :
5279+ # No character
5280+ coeff_part = s
5281+ char_part = ''
5282+
5283+ # Parse coefficient
5284+ if not coeff_part or coeff_part in ['+' , '' ]:
5285+ coefficient = Fraction (1 )
5286+ elif coeff_part == '-' :
5287+ coefficient = Fraction (- 1 )
5288+ else :
5289+ try :
5290+ coefficient = Fraction (coeff_part )
5291+ except ValueError :
5292+ try :
5293+ coefficient = Fraction (float (coeff_part )).limit_denominator ()
5294+ except ValueError as e :
5295+ print (f"Error converting coefficient: { e } " )
5296+
5297+ return coefficient , char_part
5298+
5299+ def generate_all_combinations (expressions , max_var_value = 100 ):
5300+ """Generate all possible combinations of variable assignments and calculate sums.
5301+
5302+ Args:
5303+ expressions (list): List of strings like ['2/3a', '1/3b', '1/2', 'a']
5304+ max_var_value (int): Maximum value to assign to variables (1 to max_var_value)
5305+
5306+ Returns:
5307+ list: List of tuples (sum_value, variable_assignments)
5308+ """
5309+ parsed_terms = []
5310+ variables = set ()
5311+
5312+ for expr in expressions :
5313+ coeff , var = extract_coeff_and_var (expr )
5314+ parsed_terms .append ((coeff , var ))
5315+ if var :
5316+ variables .add (var )
5317+
5318+ variables = sorted (list (variables ))
5319+
5320+ if not variables :
5321+ total_sum = sum (coeff for coeff , var in parsed_terms if not var )
5322+ yield (total_sum , {})
5323+ return
5324+
5325+ # Generate combinations one at a time
5326+ var_ranges = [range (1 , max_var_value + 1 ) for _ in variables ]
5327+
5328+ for var_values in product (* var_ranges ):
5329+ var_assignment = dict (zip (variables , var_values ))
5330+
5331+ total_sum = Fraction (0 )
5332+ for coeff , var in parsed_terms :
5333+ if var :
5334+ total_sum += coeff * var_assignment [var ]
5335+ else :
5336+ total_sum += coeff
5337+
5338+ yield (total_sum , var_assignment )
5339+
5340+ def get_unique_sums (expressions , max_var_value = 100 ):
5341+ """
5342+ Get unique sum values and count their occurrences
5343+
5344+ Returns:
5345+ dict: {sum_value: count}
5346+ """
5347+ sum_counts = {}
5348+
5349+ for total_sum , var_assignment in generate_all_combinations (expressions , max_var_value ):
5350+ if total_sum in sum_counts :
5351+ sum_counts [total_sum ] += 1
5352+ else :
5353+ sum_counts [total_sum ] = 1
5354+
5355+ return sum_counts
5356+
52555357 def match_vector_pattern (k_vec , k_vec_dict , symmetry = None ):
52565358 """Check the k-vector against the standard form in isodistort.
52575359
@@ -5264,61 +5366,24 @@ def match_vector_pattern(k_vec, k_vec_dict, symmetry=None):
52645366 str: The standard k-vector form in isodistort.
52655367 """
52665368 from itertools import permutations
5267-
5268- all_matches = list ()
5269-
5369+
5370+ all_matches = {}
5371+
52705372 for desc , pattern in k_vec_dict .items ():
52715373 if len (pattern ) != len (k_vec ):
52725374 continue
5273-
5274- # Generate sequences to check based on symmetry
5275- if symmetry == 'cubic' :
5276- # For cubic symmetry, check all permutations of k_vec
5375+
5376+ if symmetry in ['cubic' , 'rhombohedral' ]:
52775377 k_vec_sequences = list (permutations (k_vec ))
5378+ elif symmetry in ['hexagonal' , 'trigonal' , 'tetragonal' ]:
5379+ k_vec_sequences = [k_vec , (k_vec [1 ], k_vec [0 ], k_vec [2 ])]
52785380 else :
5279- # For other symmetries, maintain original order
52805381 k_vec_sequences = [k_vec ]
52815382
5282- # Check if any permutation matches the pattern
5283- pattern_matched = False
52845383 for k_vec_seq in k_vec_sequences :
5285- placeholders = {}
5286- match = True
5287- for p_val , k_val in zip (pattern , k_vec_seq ):
5288- if isinstance (p_val , str ):
5289- if p_val .isalpha ():
5290- if p_val not in placeholders :
5291- placeholders [p_val ] = k_val
5292- elif placeholders [p_val ] != k_val :
5293- match = False
5294- break
5295- else :
5296- match = False
5297- break
5298- else :
5299- if p_val != k_val :
5300- match = False
5301- break
5302-
5303- if match :
5304- pattern_matched = True
5305- break
5306-
5307- if pattern_matched :
5308- all_matches .append (desc )
5309-
5310- idp_params_num = 3
5311- for match in all_matches :
5312- params = list ()
5313- for item in k_vec_dict [match ]:
5314- if isinstance (item , str ):
5315- params .append (item )
5316- idp_params = list (set (params ))
5317- if len (idp_params ) <= idp_params_num :
5318- idp_params_num = len (idp_params )
5319- final_match = match
5320-
5321- return final_match
5384+ pass
5385+
5386+ return
53225387
53235388 k_vec_form = match_vector_pattern (k_vec , k_vec_dict , symmetry = "cubic" )
53245389
@@ -5489,7 +5554,7 @@ def match_vector_pattern(k_vec, k_vec_dict, symmetry=None):
54895554
54905555 kvec_dict = grab_all_kvecs (out2 )
54915556
5492- data_update = setup_kvec_input (kpoint_frac , kvec_dict )
5557+ data_update = setup_kvec_input (kpoint_frac , kvec_dict , symmetry = lat_sym )
54935558 for key , value in data_update .items ():
54945559 data [key ] = value
54955560
0 commit comments