Skip to content

Commit fe36991

Browse files
Merge pull request #170 from NotJoeMartinez/fix_or_and_search
Fix or and search
2 parents 6d6398e + 5f1b354 commit fe36991

File tree

2 files changed

+175
-42
lines changed

2 files changed

+175
-42
lines changed

yt_fts/db_utils.py

Lines changed: 168 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sqlite3
22
import sys
3+
import re
34

45
from sqlite_utils import Database
56
from rich.console import Console
@@ -112,54 +113,188 @@ def get_channels():
112113
return db.execute("SELECT ROWID, channel_id, channel_name, channel_url FROM Channels").fetchall()
113114

114115

115-
def search_channel(channel_id, text, limit=None):
116-
db = Database(get_db_path())
116+
def escape_fts5_query(query):
    """Wrap every FTS5 punctuation character in *query* in double quotes.

    Each occurrence of ``"``, ``*``, ``(``, ``)``, ``-`` or ``+`` is replaced
    by that same character surrounded by double quotes; all other characters
    are returned unchanged.

    NOTE(review): a literal ``"`` therefore becomes ``\"\"\"`` rather than the
    doubled ``""`` form -- confirm this is what callers expect.
    """
    # A single translation pass is enough: no replacement introduces a
    # character that would itself need to be replaced afterwards.
    quoting_table = str.maketrans({ch: f'"{ch}"' for ch in '"*()-+'})
    return query.translate(quoting_table)
121+
122+
123+
def escape_fts5_term(term):
    """Return *term* as a double-quoted FTS5 string literal.

    Embedded double quotes are escaped by doubling them, and every other
    character (``-``, ``+``, ``(``, ``)``, ...) is left inside the quotes so
    that it is treated as text instead of query syntax.

    A trailing ``*`` is kept as the prefix operator: it is emitted after the
    closing quote, so ``foo*`` becomes ``"foo"*``.

    Args:
        term: A single word or phrase, without surrounding quotes.

    Returns:
        The quoted term, optionally followed by ``*``.
    """
    # The prefix marker must live outside the string literal to stay an
    # operator; inside the quotes it would just be text.
    is_prefix = term.endswith('*')
    body = term.rstrip('*') if is_prefix else term

    quoted = '"' + body.replace('"', '""') + '"'

    # A bare "*" has nothing to prefix-match, so no operator is emitted.
    if is_prefix and body:
        return quoted + '*'
    return quoted
128+
117129

118-
words = text.split()
119-
processed_words = []
120-
for word in words:
121-
if '"' in word:
122-
processed_words.append(word.replace('"', '""'))
130+
def parse_query(query):
    """Translate a user search string into an FTS5 MATCH expression.

    The input is split into double-quoted phrases and whitespace-delimited
    words. The exact upper-case tokens ``AND`` and ``OR`` are passed through
    as-is; every other token has its surrounding double quotes stripped and
    is handed to ``escape_fts5_term``. The pieces are joined with single
    spaces.
    """
    tokens = re.findall(r'"[^"]*"|\S+', query)
    return ' '.join(
        token if token in ('AND', 'OR') else escape_fts5_term(token.strip('"'))
        for token in tokens
    )
125139

126-
processed_query = ' '.join(processed_words)
127140

128-
sql = f"""
129-
video_id IN (
130-
SELECT video_id
131-
FROM Videos
132-
WHERE channel_id = '{channel_id}'
133-
)
141+
def search_channel(channel_id, text, limit=None):
    """Full-text search the subtitles of a single channel.

    Args:
        channel_id: Value matched against ``Videos.channel_id``.
        text: User search string; converted with ``parse_query`` before
            being bound to the ``MATCH`` clause.
        limit: Optional maximum number of rows; ``None`` means no limit.

    Returns:
        A list of dicts with the keys ``rowid``, ``subtitle_id``,
        ``video_id``, ``start_time``, ``stop_time`` and ``text``, ordered by
        the ``rank`` column.

    Raises:
        sqlite3.Error: Propagated unchanged from the database layer.
    """
    sql = """
        SELECT
            s.rowid,
            s.subtitle_id,
            s.video_id,
            s.start_time,
            s.stop_time,
            s.text
        FROM
            Subtitles_fts fts
        JOIN
            Subtitles s ON fts.rowid = s.rowid
        JOIN
            Videos v ON s.video_id = v.video_id
        WHERE
            fts.text MATCH ?
            AND v.channel_id = ?
        ORDER BY
            rank
    """
    params = [parse_query(text), channel_id]
    if limit is not None:
        sql += " LIMIT ?"
        params.append(limit)

    conn = sqlite3.connect(get_db_path())
    try:
        rows = conn.execute(sql, params).fetchall()
    finally:
        # Release the connection even when the query itself fails
        # (e.g. a malformed MATCH expression).
        conn.close()

    columns = ("rowid", "subtitle_id", "video_id",
               "start_time", "stop_time", "text")
    return [dict(zip(columns, row)) for row in rows]
139188

140189

141190
def search_video(video_id, text, limit=None):
    """Full-text search the subtitles of a single video.

    Args:
        video_id: Value matched against ``Subtitles.video_id``.
        text: User search string; converted with ``parse_query`` before
            being bound to the ``MATCH`` clause.
        limit: Optional maximum number of rows; ``None`` means no limit.

    Returns:
        A list of dicts with the keys ``rowid``, ``subtitle_id``,
        ``video_id``, ``start_time``, ``stop_time`` and ``text``. The query
        has no ``ORDER BY`` clause.

    On any exception the error is printed and the process exits with
    status 1.
    """
    # Bound before the try so the finally clause can always inspect it,
    # even when opening the database is what failed.
    conn = None
    try:
        conn = sqlite3.connect(get_db_path())

        sql = """
            SELECT
                s.rowid,
                s.subtitle_id,
                s.video_id,
                s.start_time,
                s.stop_time,
                s.text
            FROM
                Subtitles_fts fts
            JOIN
                Subtitles s ON fts.rowid = s.rowid
            WHERE
                s.video_id = ?
            AND
                fts.text MATCH ?
        """
        params = [video_id, parse_query(text)]
        if limit is not None:
            sql += " LIMIT ?"
            params.append(limit)

        rows = conn.execute(sql, params).fetchall()

        columns = ("rowid", "subtitle_id", "video_id",
                   "start_time", "stop_time", "text")
        return [dict(zip(columns, row)) for row in rows]

    except Exception as e:
        print(e)
        sys.exit(1)
    finally:
        if conn is not None:
            conn.close()
147242

148243

149244
def search_all(text, limit=None):
    """Full-text search the subtitles of every stored video.

    Args:
        text: User search string; converted with ``parse_query`` before
            being bound to the ``MATCH`` clause.
        limit: Optional maximum number of rows; ``None`` means no limit.

    Returns:
        A list of dicts with the keys ``rowid``, ``subtitle_id``,
        ``video_id``, ``start_time``, ``stop_time`` and ``text``, ordered by
        the ``rank`` column.

    On any exception the error is printed and the process exits with
    status 1.
    """
    # Bound before the try so the finally clause can always inspect it,
    # even when opening the database is what failed.
    conn = None
    try:
        conn = sqlite3.connect(get_db_path())

        sql = """
            SELECT
                s.rowid,
                s.subtitle_id,
                s.video_id,
                s.start_time,
                s.stop_time,
                s.text
            FROM
                Subtitles_fts fts
            JOIN
                Subtitles s ON fts.rowid = s.rowid
            WHERE
                fts.text MATCH ?
            ORDER BY
                rank
        """
        params = [parse_query(text)]
        if limit is not None:
            sql += " LIMIT ?"
            params.append(limit)

        rows = conn.execute(sql, params).fetchall()

        columns = ("rowid", "subtitle_id", "video_id",
                   "start_time", "stop_time", "text")
        return [dict(zip(columns, row)) for row in rows]

    except Exception as e:
        print(e)
        sys.exit(1)

    finally:
        if conn is not None:
            conn.close()
163298

164299

165300
def get_title_from_db(video_id):

yt_fts/search.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,11 @@ def fts_search(text, scope, channel_id=None, video_id=None, limit=None):
3333
res = search_video(video_id, text, limit)
3434

3535
if len(res) == 0:
36-
console.print("- Try shortening the search to specific words")
37-
console.print("- Try using the wildcard operator [bold]*[/bold] to search for partial words")
38-
console.print("- Try using the [bold]OR[/bold] operator to search for multiple words")
39-
if len(text.split(" ")) > 1:
40-
example_or = text.replace(" ", " OR ")
41-
console.print(f" - EX: \"[bold]{example_or}[/bold]\"")
42-
else:
43-
console.print(f" - EX: \"[bold]foo OR [bold]bar[/bold]\"")
36+
console.print(f"[yellow]No matches found[/yellow]\n"
37+
"- Try shortening the search to specific words\n"
38+
"- Try using the wildcard operator [bold]*[/bold] to search for partial words\n"
39+
"- Try using the [bold]OR[/bold] operator to search for multiple words\n"
40+
" - EX: \"foo OR bar\"")
4441
sys.exit(1)
4542

4643
return res
@@ -147,7 +144,8 @@ def print_fts_res(res, query):
147144
num_channels = len(set(channel_names))
148145
num_videos = len(set([quote["video_id"] for quote in res]))
149146

150-
summary_str = f"Found [bold]{num_matches}[/bold] matches in [bold]{num_videos}[/bold] videos from [bold]{num_channels}[/bold] channel"
147+
summary_str = f"Found [bold]{num_matches}[/bold] matches in [bold]{num_videos}[/bold] "
148+
summary_str += f"videos from [bold]{num_channels}[/bold] channel"
151149

152150
if num_channels > 1:
153151
summary_str += "s"

0 commit comments

Comments
 (0)