Skip to content

Commit 3db3fe5

Browse files
committed
Refactor cassette matching logic to use provider names directly and improve precision in file matching
1 parent 5cc760d commit 3db3fe5

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

lib/tasks/vcr.rake

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,8 @@ def record_for_providers(providers, cassette_dir) # rubocop:disable Metrics/AbcS
2929
return
3030
end
3131

32-
# Get URL patterns from the providers themselves
33-
provider_patterns = get_provider_patterns(providers)
34-
35-
puts "Finding cassettes for providers: #{providers.join(', ')}"
36-
3732
# Find and delete matching cassettes
38-
cassettes_to_delete = find_matching_cassettes(cassette_dir, provider_patterns)
33+
cassettes_to_delete = find_matching_cassettes(cassette_dir, providers)
3934

4035
if cassettes_to_delete.empty?
4136
puts 'No cassettes found for the specified providers.'
@@ -51,33 +46,21 @@ def record_for_providers(providers, cassette_dir) # rubocop:disable Metrics/AbcS
5146
puts 'Please review the updated cassettes for sensitive information.'
5247
end
5348

54-
def get_provider_patterns(providers) # rubocop:disable Metrics/MethodLength
55-
provider_patterns = {}
56-
57-
providers.each do |provider_name|
58-
provider_module = RubyLLM::Provider.providers[provider_name.to_sym]
59-
next unless provider_module
60-
61-
# Extract the base URL from the provider's api_base method
62-
api_base = provider_module.api_base.to_s
63-
64-
# Create a regex pattern from the domain
65-
next unless api_base && !api_base.empty?
66-
67-
domain = URI.parse(api_base).host
68-
pattern = Regexp.new(Regexp.escape(domain))
69-
provider_patterns[provider_name] = pattern
70-
end
71-
72-
provider_patterns
73-
end
74-
75-
def find_matching_cassettes(dir, patterns)
49+
def find_matching_cassettes(dir, providers) # rubocop:disable Metrics/MethodLength
7650
cassettes = []
7751

7852
Dir.glob("#{dir}/**/*.yml").each do |file|
79-
content = File.read(file)
80-
cassettes << file if patterns.values.any? { |pattern| content.match?(pattern) }
53+
basename = File.basename(file)
54+
55+
# Precise matching to avoid cross-provider confusion
56+
providers.each do |provider|
57+
# Match only exact provider prefixes
58+
next unless basename =~ /^[^_]*_#{provider}_/ || # For first section like "chat_openai_"
59+
basename =~ /_#{provider}_[^_]+_/ # For middle sections like "_openai_gpt4_"
60+
61+
cassettes << file
62+
break
63+
end
8164
end
8265

8366
cassettes

0 commit comments

Comments
 (0)