diff --git a/.github/workflows/build_image.yml b/.github/workflows/build_image.yml new file mode 100644 index 0000000..475763b --- /dev/null +++ b/.github/workflows/build_image.yml @@ -0,0 +1,41 @@ +name: DockerBuildAndPush + +on: + push: + branches: + - master + - developement + - ptb-async + +env: + IMAGE_NAME: transcriberbot + +jobs: + push: + runs-on: ubuntu-latest + if: github.event_name == 'push' + + steps: + - uses: actions/checkout@v2 + + - name: Login to ghcr registry + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u $ --password-stdin + + - name: Build image + run: docker build . --file Dockerfile --tag $IMAGE_NAME + + - name: Push image + run: | + IMAGE_ID=ghcr.io/${{ github.repository_owner }}/$IMAGE_NAME + # Change all uppercase to lowercase + IMAGE_ID=$(echo $IMAGE_ID | tr '[A-Z]' '[a-z]') + # Strip git ref prefix from version + VERSION=$(echo "${{ github.ref }}" | sed -e 's,.*/\(.*\),\1,') + # Strip "v" prefix from tag name + [[ "${{ github.ref }}" == "refs/tags/"* ]] && VERSION=$(echo $VERSION | sed -e 's/^v//') + # Use Docker `latest` tag convention + [ "$VERSION" == "master" ] && VERSION=latest + echo IMAGE_ID=$IMAGE_ID + echo VERSION=$VERSION + docker tag $IMAGE_NAME $IMAGE_ID:$VERSION + docker push $IMAGE_ID:$VERSION \ No newline at end of file diff --git a/.gitignore b/.gitignore index 2547aad..39ace71 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ # TranscriberBot-specific ignores -config/ media/ # Generic data-related ignores diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..275cfdd --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +transcriber-bot-wonda diff --git a/Dockerfile b/Dockerfile index e42b32f..4e882af 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.12.0-slim # Set global configs WORKDIR / diff --git a/dockerBuild.sh b/build.sh similarity index 100% rename from dockerBuild.sh rename to build.sh diff --git a/config/app.json b/config/app.json index 3eba13a..fad2d12 100644 --- a/config/app.json +++ b/config/app.json @@ -49,6 +49,10 @@ "webm" ], + "ocr": { + "tesseract_path": "/usr/share/tesseract-ocr/5/tessdata/" + }, + "antiflood": { "age_threshold": 10, "flood_ratio": 2, @@ -56,5 +60,13 @@ "time_threshold_warning": 4, "time_threshold_flood": 5, "timeout": 10 + }, + + "whisper": { + "api_endpoint": "http://127.0.0.1:8000" + }, + + "logging": { + "level": "APP" } } diff --git a/config/sentry.json b/config/sentry.json new file mode 100644 index 0000000..cb8ef39 --- /dev/null +++ b/config/sentry.json @@ -0,0 +1,3 @@ +{ + "dsn": "xxx" +} \ No newline at end of file diff --git a/dockerRun.sh b/dockerRun.sh deleted file mode 100755 index 3826e74..0000000 --- a/dockerRun.sh +++ /dev/null @@ -1,8 +0,0 @@ -docker run \ - -e LC_ALL=C \ - -d --restart unless-stopped \ - --name "transcriberbot" \ - -v "$(pwd)"/data:/data \ - -v "$(pwd)"/config:/config \ - -v "$(pwd)"/values:/values \ - transcriberbot diff --git a/requirements.txt b/requirements.txt index b29f863..8b8bc35 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -python-telegram-bot==12.3.0 +python-telegram-bot coloredlogs pillow watchdog @@ -6,3 +6,4 @@ tesserocr pydub zbarlight requests +sentry-sdk \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..0fe85f7 --- /dev/null +++ b/run.sh @@ -0,0 +1,15 @@ +#!/bin/sh + +docker pull ghcr.io/charslab/transcriberbot:ptb-async +docker run \ + -e LC_ALL=C \ + -d --restart unless-stopped \ + --name "transcriberbot-async" \ + -v "$(pwd)"/data:/data \ + -v "$(pwd)"/config:/config \ + -v "$(pwd)"/values:/values \ + -v "$(pwd)"/media:/media \ + --cpus=4.0 \ + --memory=3000m \ + -u "$(id -u):1337" \ + ghcr.io/charslab/transcriberbot:ptb-async \ No newline at end of file diff --git a/src/antiflood/__init__.py b/src/antiflood/__init__.py index f5065ae..429e8ea 100644 --- a/src/antiflood/__init__.py +++ b/src/antiflood/__init__.py @@ -2,4 +2,4 @@ from antiflood.antiflood import register_flood_warning_callback from antiflood.antiflood import register_flood_started_callback from antiflood.antiflood import register_flood_ended_callback -from antiflood.antiflood import init \ No newline at end of file +from antiflood.antiflood import init diff --git a/src/antiflood/antiflood.py b/src/antiflood/antiflood.py index 16604ef..b32256f 100644 --- a/src/antiflood/antiflood.py +++ b/src/antiflood/antiflood.py @@ -4,11 +4,11 @@ logger = logging.getLogger(__name__) -flood_ratio = 2 # messages/seconds +flood_ratio = 2 # messages/seconds max_flood_ratio = 10 time_threshold_warning = 5 # ratio > flood_ratio for {time_threshold_warning} seconds -time_threshold_flood = 10 # ratio > flood_ratio for {time_threshold_flood} seconds -timeout = 4 # flood ends after ratio < flood_ratio for {timeout} seconds +time_threshold_flood = 10 # ratio > flood_ratio for {time_threshold_flood} seconds +timeout = 4 # flood ends after ratio < flood_ratio for {timeout} seconds callback_flood_warning = None callback_flood_started = None @@ -20,74 +20,77 @@ # chat_id -> (level, ratio, msg_num, duration, last_update) stats = {} -def register_flood_warning_callback(callback): - global callback_flood_warning - callback_flood_warning = callback - -def register_flood_started_callback(callback): - global callback_flood_started - callback_flood_started = callback - -def register_flood_ended_callback(callback): - global callback_flood_ended - callback_flood_ended = callback - -def init(): - global flood_ratio, max_flood_ratio, time_threshold_warning, time_threshold_flood, timeout - flood_ratio = config.get_config_prop("app")["antiflood"]["flood_ratio"] - max_flood_ratio = config.get_config_prop("app")["antiflood"]["max_flood_ratio"] - time_threshold_warning = config.get_config_prop("app")["antiflood"]["time_threshold_warning"] - time_threshold_flood = config.get_config_prop("app")["antiflood"]["time_threshold_flood"] - timeout = config.get_config_prop("app")["antiflood"]["timeout"] - - logger.info("Ratio: %d", flood_ratio) - logger.info("Max flood ratio: %d", max_flood_ratio) - logger.info("Thr warning: %d", time_threshold_warning) - logger.info("Thr flood: %d", time_threshold_flood) - logger.info("Timeout: %d", timeout) -def on_chat_msg_received(chat_id): - global flood_ratio, time_threshold_warning, time_threshold_flood, timeout - global callback_flood_warning, callback_flood_started, callback_flood_ended - - curr_time = time.time() - - if chat_id not in stats: - stats[chat_id] = [LEVEL_NORMAL, 1.0, 1, 0.0, curr_time] +def register_flood_warning_callback(callback): + global callback_flood_warning + callback_flood_warning = callback - else: - level, ratio, msg_num, duration, last_update = stats[chat_id] - updated_duration = duration + curr_time - last_update - msg_num += 1 - curr_ratio = msg_num / updated_duration - if curr_ratio < flood_ratio and updated_duration > timeout: - curr_ratio, updated_duration, msg_num = 0, 0, 0 - level = LEVEL_NORMAL - if callback_flood_ended: - callback_flood_ended(chat_id) +def register_flood_started_callback(callback): + global callback_flood_started + callback_flood_started = callback - elif updated_duration > 1 and curr_ratio > max_flood_ratio and level < LEVEL_FLOOD: - level = LEVEL_FLOOD - logger.warning("Flood ratio for chat %d is over the top", chat_id) - if callback_flood_started: - callback_flood_started(chat_id) - elif curr_ratio > flood_ratio: - if updated_duration >= time_threshold_flood and level < LEVEL_FLOOD: - logger.warning("Flood detected for chat %d", chat_id) - level = LEVEL_FLOOD - if callback_flood_started: - callback_flood_started(chat_id) +def register_flood_ended_callback(callback): + global callback_flood_ended + callback_flood_ended = callback - elif updated_duration >= time_threshold_warning and level < LEVEL_WARNING: - logger.info("Potential flood for chat %d", chat_id) - level = LEVEL_WARNING - if callback_flood_warning is not None: - callback_flood_warning(chat_id) - stats[chat_id] = (level, curr_ratio, msg_num, updated_duration, curr_time) +def init(): + global flood_ratio, max_flood_ratio, time_threshold_warning, time_threshold_flood, timeout + flood_ratio = config.get_config_prop("app")["antiflood"]["flood_ratio"] + max_flood_ratio = config.get_config_prop("app")["antiflood"]["max_flood_ratio"] + time_threshold_warning = config.get_config_prop("app")["antiflood"]["time_threshold_warning"] + time_threshold_flood = config.get_config_prop("app")["antiflood"]["time_threshold_flood"] + timeout = config.get_config_prop("app")["antiflood"]["timeout"] - logger.info("stats[{}]: {}".format(chat_id, stats[chat_id])) + logger.info("Ratio: %d", flood_ratio) + logger.info("Max flood ratio: %d", max_flood_ratio) + logger.info("Thr warning: %d", time_threshold_warning) + logger.info("Thr flood: %d", time_threshold_flood) + logger.info("Timeout: %d", timeout) +def on_chat_msg_received(chat_id): + global flood_ratio, time_threshold_warning, time_threshold_flood, timeout + global callback_flood_warning, callback_flood_started, callback_flood_ended + + curr_time = time.time() + + if chat_id not in stats: + stats[chat_id] = [LEVEL_NORMAL, 1.0, 1, 0.0, curr_time] + + else: + level, ratio, msg_num, duration, last_update = stats[chat_id] + updated_duration = duration + curr_time - last_update + msg_num += 1 + curr_ratio = msg_num / updated_duration + + if curr_ratio < flood_ratio and updated_duration > timeout: + curr_ratio, updated_duration, msg_num = 0, 0, 0 + level = LEVEL_NORMAL + if callback_flood_ended: + callback_flood_ended(chat_id) + + elif updated_duration > 1 and curr_ratio > max_flood_ratio and level < LEVEL_FLOOD: + level = LEVEL_FLOOD + logger.warning("Flood ratio for chat %d is over the top", chat_id) + if callback_flood_started: + callback_flood_started(chat_id) + + elif curr_ratio > flood_ratio: + if updated_duration >= time_threshold_flood and level < LEVEL_FLOOD: + logger.warning("Flood detected for chat %d", chat_id) + level = LEVEL_FLOOD + if callback_flood_started: + callback_flood_started(chat_id) + + elif updated_duration >= time_threshold_warning and level < LEVEL_WARNING: + logger.info("Potential flood for chat %d", chat_id) + level = LEVEL_WARNING + if callback_flood_warning is not None: + callback_flood_warning(chat_id) + + stats[chat_id] = (level, curr_ratio, msg_num, updated_duration, curr_time) + + logger.info("stats[{}]: {}".format(chat_id, stats[chat_id])) diff --git a/src/audiotools/__init__.py b/src/audiotools/__init__.py index b63ff77..22aa887 100644 --- a/src/audiotools/__init__.py +++ b/src/audiotools/__init__.py @@ -1 +1 @@ -from audiotools.speech import transcribe \ No newline at end of file +from audiotools.speech import transcribe diff --git a/src/audiotools/speech.py b/src/audiotools/speech.py index bebfff2..188bf8b 100644 --- a/src/audiotools/speech.py +++ b/src/audiotools/speech.py @@ -1,110 +1,140 @@ import io import logging import traceback +import os +import asyncio import requests +from functools import partial + import pydub from pydub import AudioSegment - +import config +import textwrap logger = logging.getLogger("speech") class WitTranscriber: - speech_url = "https://api.wit.ai/speech" - - def __init__(self, api_key): - self.session = requests.Session() - self.session.headers.update( - { - "Authorization": "Bearer " + api_key, - "Accept": "application/vnd.wit.20180705+json", - "Content-Type": "audio/raw;encoding=signed-integer;bits=16;rate=8000;endian=little", - } - ) - - def transcribe(self, chunk): - text = None - try: - response = self.session.post( - self.speech_url, - params={"verbose": True}, - data=io.BufferedReader(io.BytesIO(chunk.raw_data)) - ) - logger.debug("Request response %s", response.text) - data = response.json() - if "_text" in data: - text = data["_text"] - elif "text" in data: # Changed in may 2020 - text = data["text"] - - except requests.exceptions.RequestException as e: - logger.error("Could not transcribe chunk: %s", traceback.format_exc()) - - return text - - def close(self): - self.session.close() - - -def __generate_chunks(segment, length=20000/1001, split_on_silence=False, noise_threshold=-36): - chunks = list() - if split_on_silence is False: - for i in range(0, len(segment), int(length*1000)): - chunks.append(segment[i:i+int(length*1000)]) - else: - while len(chunks) < 1: - logger.debug('split_on_silence (threshold %d)', noise_threshold) - chunks = pydub.silence.split_on_silence(segment, noise_threshold) - noise_threshold += 4 + speech_url = "https://api.wit.ai/speech" + + def __init__(self, api_key): + self.session = requests.Session() + self.session.headers.update( + { + "Authorization": "Bearer " + api_key, + "Accept": "application/vnd.wit.20180705+json", + "Content-Type": "audio/raw;encoding=signed-integer;bits=16;rate=8000;endian=little", + } + ) + + async def transcribe(self, chunk): + text = None + try: + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, partial(self.session.post, + url=self.speech_url, + params={"verbose": True}, + data=io.BufferedReader(io.BytesIO(chunk.raw_data))) + ) + + logger.debug("Request response %s", response.text) + data = response.json() + if "_text" in data: + text = data["_text"] + elif "text" in data: # Changed in may 2020 + text = data["text"] + + except requests.exceptions.RequestException as e: + logger.error("Could not transcribe chunk", exc_info=True) + + return text + + def close(self): + self.session.close() + + +def __generate_chunks(segment, length=20000 / 1001, split_on_silence=False, noise_threshold=-36): + chunks = list() + if split_on_silence is False: + for i in range(0, len(segment), int(length * 1000)): + chunks.append(segment[i:i + int(length * 1000)]) + else: + while len(chunks) < 1: + logger.debug('split_on_silence (threshold %d)', noise_threshold) + chunks = pydub.silence.split_on_silence(segment, noise_threshold) + noise_threshold += 4 + + for i, chunk in enumerate(chunks): + if len(chunk) > int(length * 1000): + subchunks = __generate_chunks(chunk, length, split_on_silence, noise_threshold + 4) + chunks = chunks[:i - 1] + subchunks + chunks[i + 1:] + + return chunks + + +def __preprocess_audio(audio): + return audio.set_sample_width(2).set_channels(1).set_frame_rate(8000) + +async def transcribe_wit(path, api_key): + logger.info("Transcribing file %s", path) + audio = AudioSegment.from_file(path) + + chunks = __generate_chunks(__preprocess_audio(audio)) + logger.debug("Got %d chunks", len(chunks)) + + transcriber = WitTranscriber(api_key) for i, chunk in enumerate(chunks): - if len(chunk) > int(length*1000): - subchunks = __generate_chunks(chunk, length, split_on_silence, noise_threshold+4) - chunks = chunks[:i-1] + subchunks + chunks[i+1:] + logger.debug("Transcribing chunk %d", i) + text = await transcriber.transcribe(chunk) + logger.debug("Response received: %s", text) - return chunks + if text is not None: + yield i, text, len(chunks) + transcriber.close() -def __preprocess_audio(audio): - return audio.set_sample_width(2).set_channels(1).set_frame_rate(8000) +async def transcribe_whisper(path): + resp = requests.get(f"{config.get_config_prop('app')['whisper']['api_endpoint']}/transcribe?file_id={path}") + + # split the response into chunks of 4000 characters + chunks = textwrap.wrap(resp.text, 4000) + for idx, chunk in enumerate(chunks): + yield idx, chunk, len(chunks) -def transcribe(path, api_key): - logger.info("Transcribing file %s", path) - audio = AudioSegment.from_file(path) - chunks = __generate_chunks(__preprocess_audio(audio)) - logger.debug("Got %d chunks", len(chunks)) +def transcribe(path, api_key, backend="wit"): + if backend == "wit": + logging.debug("Transcribing with wit") + return transcribe_wit(path, api_key) - transcriber = WitTranscriber(api_key) - for i, chunk in enumerate(chunks): - logger.debug("Transcribing chunk %d", i) - text = transcriber.transcribe(chunk) - logger.debug("Response received: %s", text) + elif backend == "whisper": + logging.debug("Transcribing with whisper") + return transcribe_whisper(os.path.basename(path)) - if text is not None: - yield text - transcriber.close() + raise ValueError("Unknown backend: %s" % backend) if __name__ == "__main__": - import argparse - import sys - - parser = argparse.ArgumentParser() - parser.add_argument("api_key") - parser.add_argument("input_filename") - parser.add_argument("output_filename") - args = parser.parse_args() - - if args.output_filename == "-": - output = sys.stdout - else: - output = open(args.output_filename, mode="w") - - result = transcribe(args.input_filename, args.api_key) - for part in result: - output.write(part + "\n") - output.flush() - - output.close() + import argparse + import sys + + parser = argparse.ArgumentParser() + parser.add_argument("api_key") + parser.add_argument("input_filename") + parser.add_argument("output_filename") + args = parser.parse_args() + + if args.output_filename == "-": + output = sys.stdout + else: + output = open(args.output_filename, mode="w") + + result = transcribe(args.input_filename, args.api_key) + for part, tot in result: + output.write(part + "\n") + output.flush() + + output.close() diff --git a/src/config/__init__.py b/src/config/__init__.py index c71e460..9986c2a 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -2,39 +2,69 @@ import json import functional import logging - import pprint +APP_LOG = 25 + logger = logging.getLogger(__name__) __configs = {} + def parse_file(file): - logger.info("Loading config file %s", file) + logger.info("Loading config file %s", file) + + with open(file) as f: + data = json.load(f) + return data - with open(file) as f: - data = json.load(f) - return data def init(config_folder): - global __configs - files = glob.glob(os.path.join(config_folder, "*.json")) + global __configs + files = glob.glob(os.path.join(config_folder, "*.json")) + + keys = [x.replace(config_folder, "").replace(".json", "").replace("/", "") for x in files] + configs = map(parse_file, files) + __configs = dict(zip(keys, configs)) - keys = [x.replace(config_folder, "").replace(".json", "").replace("/", "") for x in files] - configs = map(parse_file, files) - __configs = dict(zip(keys, configs)) + base = os.path.join(os.path.dirname(__file__), "../../") - base = os.path.join(os.path.dirname(__file__), "../../") + if not os.path.isabs(__configs['app']['database']): + __configs['app']['database'] = os.path.join(base, __configs['app']['database']) - if not os.path.isabs(__configs['app']['database']): - __configs['app']['database'] = os.path.join(base, __configs['app']['database']) + if not os.path.isabs(__configs['app']['media_path']): + __configs['app']['media_path'] = os.path.join(base, __configs['app']['media_path']) - if not os.path.isabs(__configs['app']['media_path']): - __configs['app']['media_path'] = os.path.join(base, __configs['app']['media_path']) + if not os.path.isdir(__configs['app']['media_path']): + os.mkdir(__configs['app']['media_path']) - if not os.path.isdir(__configs['app']['media_path']): - os.mkdir(__configs['app']['media_path']) def get_config_prop(key): - global __configs - return __configs[key] + global __configs + return __configs[key] + + +def bot_token(): + return get_config_prop("telegram")["token"] + + +def get_language_list(): + return get_config_prop("app")["languages"].keys() + + +def get_audio_extensions(): + return get_config_prop("app").get("audio_ext", []) + + +def get_video_extensions(): + return get_config_prop("app").get("video_ext", []) + + +def get_document_extensions(): + audio_ext = get_audio_extensions() + video_ext = get_video_extensions() + return audio_ext + video_ext + + +def get_bot_admins(): + return [int(id) for id in get_config_prop("telegram")["admins"]] diff --git a/src/database/__init__.py b/src/database/__init__.py index 526f15f..7a23c2d 100644 --- a/src/database/__init__.py +++ b/src/database/__init__.py @@ -16,22 +16,25 @@ """ + def init_schema(database): - with Database(database) as db: - db.execute("""CREATE TABLE IF NOT EXISTS chats ( - chat_id INTEGER PRIMARY KEY, - lang VARCHAR(5) NOT NULL, - voice_enabled INTEGER, - photos_enabled INTEGER, - qr_enabled INTEGER, - active INTEGER, - ban INTEGER) - """) + with Database(database) as db: + db.execute( + "CREATE TABLE IF NOT EXISTS chats (" + "chat_id INTEGER PRIMARY KEY, " + "lang VARCHAR(5) NOT NULL, " + "voice_enabled INTEGER," + "photos_enabled INTEGER," + "qr_enabled INTEGER," + "active INTEGER," + "ban INTEGER)" + ) - db.execute("""CREATE TABLE IF NOT EXISTS stats ( - month_year INTEGER PRIMARY KEY, - audio_num INTEGER, - min_tot_audio INTEGER, - min_transcribed_audio INTEGER, - num_pictures INTEGER) - """) \ No newline at end of file + db.execute( + "CREATE TABLE IF NOT EXISTS stats (" + "month_year INTEGER PRIMARY KEY," + "audio_num INTEGER, " + "min_tot_audio INTEGER," + "min_transcribed_audio INTEGER," + "num_pictures INTEGER)" + ) diff --git a/src/database/db.py b/src/database/db.py index 4a92579..a928ab0 100644 --- a/src/database/db.py +++ b/src/database/db.py @@ -6,158 +6,156 @@ logger = logging.getLogger(__name__) + class Database(): - __instance = None + __instance = None - def __init__(self, database): - self.database = database + def __init__(self, database): + self.database = database - def __connect(self): - self.__connection = sqlite3.connect(self.database) - self.__cursor = self.__connection.cursor() + def __connect(self): + self.__connection = sqlite3.connect(self.database) + self.__cursor = self.__connection.cursor() - def __close(self): - self.__connection.commit() - self.__connection.close() + def __close(self): + self.__connection.commit() + self.__connection.close() - def __enter__(self): - logger.debug("__enter__") - self.__connect() - return self + def __enter__(self): + logger.debug("__enter__") + self.__connect() + return self - def assoc(self): - self.__connection.row_factory = sqlite3.Row - self.__cursor = self.__connection.cursor() + def assoc(self): + self.__connection.row_factory = sqlite3.Row + self.__cursor = self.__connection.cursor() + def __exit__(self, exc_type, exc_value, exc_traceback): + logger.debug("__exit__") + self.__close() - def __exit__(self, exc_type, exc_value, exc_traceback): - logger.debug("__exit__") - self.__close() + if exc_type: + logger.error("exc_type: {}".format(exc_type)) + logger.error("exc_value: {}".format(exc_value)) + logger.error("exc_traceback: {}".format(exc_traceback)) + logger.error("Caught exception", exc_info=True) - if exc_type: - logger.error("exc_type: {}".format(exc_type)) - logger.error("exc_value: {}".format(exc_value)) - logger.error("exc_traceback: {}".format(exc_traceback)) - logger.error(traceback.format_exc()) + return True - return True + def execute(self, query, *args): + res = self.__cursor.execute(query, *args) + return self.__cursor - def execute(self, query, *args): - res = self.__cursor.execute(query, *args) - return self.__cursor class TBDB(): - @staticmethod - def _get_db(): - return Database(config.get_config_prop("app")["database"]) - - - @staticmethod - def create_default_chat_entry(chat_id, lang): - with TBDB._get_db() as db: - db.execute( - "INSERT INTO chats(chat_id, lang, voice_enabled, photos_enabled, qr_enabled, active, ban) VALUES(?,?,?,?,?,?,?)", - (chat_id, lang, 1, 0, 0, 1, 0) - ) - - @staticmethod - def get_chat_entry(chat_id): - with TBDB._get_db() as db: - db.assoc() - cursor = db.execute("SELECT * FROM chats WHERE chat_id='{0}'".format(chat_id)) - return cursor.fetchone() - - @staticmethod - def get_chats(): - with TBDB._get_db() as db: - db.assoc() - cursor = db.execute("SELECT * FROM chats") - return [dict(x) for x in cursor.fetchall()] - - @staticmethod - def get_chat_lang(chat_id): - chat_record = TBDB.get_chat_entry(chat_id) - if not chat_record: - logger.debug("Record for chat {} not found, creating one.".format(chat_id)) - TBDB.create_default_chat_entry(chat_id, "en-US") - - with TBDB._get_db() as db: - cursor = db.execute("SELECT lang FROM chats WHERE chat_id='{0}'".format(chat_id)) - return cursor.fetchone()[0] - - @staticmethod - def set_chat_lang(chat_id, lang): - with TBDB._get_db() as db: - db.execute("UPDATE chats SET lang='{0}' WHERE chat_id='{1}'".format(lang, chat_id)) - - - @staticmethod - def get_chat_voice_enabled(chat_id): - with TBDB._get_db() as db: - c = db.execute("SELECT voice_enabled FROM chats WHERE chat_id='{0}'".format(chat_id)) - return c.fetchone()[0] - - @staticmethod - def set_chat_voice_enabled(chat_id, voice_enabled): - with TBDB._get_db() as db: - db.execute("UPDATE chats SET voice_enabled='{0}' WHERE chat_id='{1}'".format(voice_enabled, chat_id)) - - - @staticmethod - def get_chat_photos_enabled(chat_id): - with TBDB._get_db() as db: - c = db.execute("SELECT photos_enabled FROM chats WHERE chat_id='{0}'".format(chat_id)) - return c.fetchone()[0] - - @staticmethod - def set_chat_photos_enabled(chat_id, photos_enabled): - with TBDB._get_db() as db: - db.execute("UPDATE chats SET photos_enabled='{0}' WHERE chat_id='{1}'".format(photos_enabled, chat_id)) - - - @staticmethod - def get_chat_qr_enabled(chat_id): - with TBDB._get_db() as db: - c = db.execute("SELECT qr_enabled FROM chats WHERE chat_id='{0}'".format(chat_id)) - return c.fetchone()[0] - - @staticmethod - def set_chat_qr_enabled(chat_id, qr_enabled): - with TBDB._get_db() as db: - db.execute("UPDATE chats SET qr_enabled='{0}' WHERE chat_id='{1}'".format(qr_enabled, chat_id)) - - @staticmethod - def get_chat_active(chat_id): - with TBDB._get_db() as db: - c = db.execute("SELECT active FROM chats WHERE chat_id='{0}'".format(chat_id)) - return c.fetchone()[0] - - @staticmethod - def set_chat_active(chat_id, active): - with TBDB._get_db() as db: - db.execute("UPDATE chats SET active='{0}' WHERE chat_id='{1}'".format(active, chat_id)) - - @staticmethod - def get_chat_ban(chat_id): - with TBDB._get_db() as db: - c = db.execute("SELECT ban FROM chats WHERE chat_id='{0}'".format(chat_id)) - return c.fetchone()[0] - - @staticmethod - def set_chat_ban(chat_id, ban): - with TBDB._get_db() as db: - db.execute("UPDATE chats SET ban='{0}' WHERE chat_id='{1}'".format(ban, chat_id)) - - - @staticmethod - def get_chats_num(): - with TBDB._get_db() as db: - c = db.execute("SELECT count(*) FROM chats") - return int(c.fetchone()[0]) - - @staticmethod - def get_active_chats_num(): - with TBDB._get_db() as db: - c = db.execute("SELECT count(*) FROM chats where active=1") - return int(c.fetchone()[0]) - + @staticmethod + def _get_db(): + return Database(config.get_config_prop("app")["database"]) + + @staticmethod + def create_default_chat_entry(chat_id, lang): + with TBDB._get_db() as db: + db.execute( + "INSERT INTO chats(chat_id, lang, voice_enabled, photos_enabled, qr_enabled, active, ban) VALUES(?,?,?,?,?,?,?)", + (chat_id, lang, 1, 0, 0, 1, 0) + ) + + @staticmethod + def get_chat_entry(chat_id): + with TBDB._get_db() as db: + db.assoc() + cursor = db.execute("SELECT * FROM chats WHERE chat_id='{0}'".format(chat_id)) + return cursor.fetchone() + + @staticmethod + def get_chats(): + with TBDB._get_db() as db: + db.assoc() + cursor = db.execute("SELECT * FROM chats") + return [dict(x) for x in cursor.fetchall()] + + @staticmethod + def get_chat_lang(chat_id): + chat_record = TBDB.get_chat_entry(chat_id) + if not chat_record: + logger.debug("Record for chat {} not found, creating one.".format(chat_id)) + TBDB.create_default_chat_entry(chat_id, "en-US") + return "en-US" + + return chat_record["lang"] + + @staticmethod + def set_chat_lang(chat_id, lang): + with TBDB._get_db() as db: + db.execute("UPDATE chats SET lang='{0}' WHERE chat_id='{1}'".format(lang, chat_id)) + + @staticmethod + def get_chat_voice_enabled(chat_id): + try: + with TBDB._get_db() as db: + c = db.execute("SELECT voice_enabled FROM chats WHERE chat_id='{0}'".format(chat_id)) + return c.fetchone()[0] + except TypeError as e: + logger.error("Error getting voice_enabled for chat %d: %s", chat_id, e) + raise e + + @staticmethod + def set_chat_voice_enabled(chat_id, voice_enabled): + with TBDB._get_db() as db: + db.execute("UPDATE chats SET voice_enabled='{0}' WHERE chat_id='{1}'".format(voice_enabled, chat_id)) + + @staticmethod + def get_chat_photos_enabled(chat_id): + with TBDB._get_db() as db: + c = db.execute("SELECT photos_enabled FROM chats WHERE chat_id='{0}'".format(chat_id)) + return c.fetchone()[0] + + @staticmethod + def set_chat_photos_enabled(chat_id, photos_enabled): + with TBDB._get_db() as db: + db.execute("UPDATE chats SET photos_enabled='{0}' WHERE chat_id='{1}'".format(photos_enabled, chat_id)) + + @staticmethod + def get_chat_qr_enabled(chat_id): + with TBDB._get_db() as db: + c = db.execute("SELECT qr_enabled FROM chats WHERE chat_id='{0}'".format(chat_id)) + return c.fetchone()[0] + + @staticmethod + def set_chat_qr_enabled(chat_id, qr_enabled): + with TBDB._get_db() as db: + db.execute("UPDATE chats SET qr_enabled='{0}' WHERE chat_id='{1}'".format(qr_enabled, chat_id)) + + @staticmethod + def get_chat_active(chat_id): + with TBDB._get_db() as db: + c = db.execute("SELECT active FROM chats WHERE chat_id='{0}'".format(chat_id)) + return c.fetchone()[0] + + @staticmethod + def set_chat_active(chat_id, active): + with TBDB._get_db() as db: + db.execute("UPDATE chats SET active='{0}' WHERE chat_id='{1}'".format(active, chat_id)) + + @staticmethod + def get_chat_ban(chat_id): + with TBDB._get_db() as db: + c = db.execute("SELECT ban FROM chats WHERE chat_id='{0}'".format(chat_id)) + return c.fetchone()[0] + + @staticmethod + def set_chat_ban(chat_id, ban): + with TBDB._get_db() as db: + db.execute("UPDATE chats SET ban='{0}' WHERE chat_id='{1}'".format(ban, chat_id)) + + @staticmethod + def get_chats_num(): + with TBDB._get_db() as db: + c = db.execute("SELECT count(*) FROM chats") + return int(c.fetchone()[0]) + + @staticmethod + def get_active_chats_num(): + with TBDB._get_db() as db: + c = db.execute("SELECT count(*) FROM chats where active=1") + return int(c.fetchone()[0]) diff --git a/src/functional/__init__.py b/src/functional/__init__.py index e1a5fde..6a686cc 100644 --- a/src/functional/__init__.py +++ b/src/functional/__init__.py @@ -1,3 +1,3 @@ def apply_fn(list, fn): - for item in list: - fn(item) \ No newline at end of file + for item in list: + fn(item) diff --git a/src/main.py b/src/main.py index 7bb992b..971368d 100644 --- a/src/main.py +++ b/src/main.py @@ -1,25 +1,49 @@ -import coloredlogs,logging +import logging + import config import resources import database import antiflood +import transcriberbot.bot +import sentry_sdk +from sentry_sdk.integrations.asyncio import AsyncioIntegration -from telegram.ext import Filters +def main(): + config.init('../config') -import transcriberbot -from transcriberbot import TranscriberBot + logging.addLevelName(config.APP_LOG, "APP") -coloredlogs.install( - level='DEBUG', - fmt='%(asctime)s - %(name)s - %(levelname)s - %(filename)s [%(funcName)s:%(lineno)d] - %(message)s' -) -logger = logging.getLogger(__name__) + log_level = config.get_config_prop("app")["logging"]["level"] + logging.basicConfig( + format='%(asctime)s - %(name)s - %(levelname)s - %(filename)s [%(funcName)s:%(lineno)d] - %(message)s', + level=log_level + ) + logging.log(config.APP_LOG, "Setting log level to %s", log_level) -if __name__ == '__main__': - config.init('../config') - resources.init("../values") - antiflood.init() - transcriberbot.init() - database.init_schema(config.get_config_prop("app")["database"]) + resources.init("../values") + antiflood.init() + database.init_schema(config.get_config_prop("app")["database"]) + + sentry_sdk.init( + dsn=config.get_config_prop("sentry")["dsn"], + # Add data like request headers and IP for users, if applicable; + # see https://docs.sentry.io/platforms/python/data-management/data-collected/ for more info + send_default_pii=True, + # Set traces_sample_rate to 1.0 to capture 100% + # of transactions for tracing. + traces_sample_rate=1.0, + # Set profiles_sample_rate to 1.0 to profile 100% + # of sampled transactions. + # We recommend adjusting this value in production. + profiles_sample_rate=1.0, + integrations=[ + AsyncioIntegration(), + ], + ) - TranscriberBot.get().start(config.get_config_prop("telegram")["token"]) \ No newline at end of file + sentry_sdk.profiler.start_profiler() + transcriberbot.bot.run(config.bot_token()) + + +if __name__ == '__main__': + main() diff --git a/src/metaclass/__init__.py b/src/metaclass/__init__.py index 81c9a72..657df0c 100644 --- a/src/metaclass/__init__.py +++ b/src/metaclass/__init__.py @@ -1 +1 @@ -from metaclass.singleton import Singleton \ No newline at end of file +from metaclass.singleton import Singleton diff --git a/src/metaclass/singleton.py b/src/metaclass/singleton.py index d199fb8..162eb15 100644 --- a/src/metaclass/singleton.py +++ b/src/metaclass/singleton.py @@ -1,8 +1,8 @@ class Singleton(type): - _instances = {} + _instances = {} - def __call__(cls, *args, **kwargs): - k = (cls, args) - if k not in cls._instances: - cls._instances[k] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[k] \ No newline at end of file + def __call__(cls, *args, **kwargs): + k = (cls, args) + if k not in cls._instances: + cls._instances[k] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[k] diff --git a/src/phototools/__init__.py b/src/phototools/__init__.py index 3fe0f0d..6b2f26f 100644 --- a/src/phototools/__init__.py +++ b/src/phototools/__init__.py @@ -1,2 +1,2 @@ from phototools.ocr import image_ocr -from phototools.qr import read_qr \ No newline at end of file +from phototools.qr import read_qr diff --git a/src/phototools/ocr.py b/src/phototools/ocr.py index c682ac7..163ac13 100644 --- a/src/phototools/ocr.py +++ b/src/phototools/ocr.py @@ -2,13 +2,41 @@ from tesserocr import PyTessBaseAPI +import config + logger = logging.getLogger(__name__) + def image_ocr(path, lang): - logger.info("opening %s", path) + return image_ocr_tesserocr(path, lang) + + +def image_ocr_docts(path, lang): + from doctr.models import ocr_predictor + + predictor = ocr_predictor.create_predictor() + + # Perform OCR on the image + predictor(path) + + +def image_ocr_easyocr(path, lang): + import easyocr + + logger.info("opening %s", path) + + reader = easyocr.Reader(['en'], gpu=False) + result = reader.readtext(path) + text = " ".join([x[1] for x in result]) + + return text + + +def image_ocr_tesserocr(path, lang): + logger.info("opening %s", path) - with PyTessBaseAPI() as api: - api.SetImageFile(path) - text = api.GetUTF8Text().strip() + with PyTessBaseAPI(path=config.get_config_prop("app")["ocr"]["tesseract_path"]) as api: + api.SetImageFile(path) + text = api.GetUTF8Text().strip() - return text \ No newline at end of file + return text diff --git a/src/phototools/qr.py b/src/phototools/qr.py index 5c1e8a3..605f815 100644 --- a/src/phototools/qr.py +++ b/src/phototools/qr.py @@ -4,15 +4,15 @@ logger = logging.getLogger(__name__) -def read_qr(path, lang): - logger.info("opening %s", path) - with open(path, 'rb') as f: - image = Image.open(f) - image.load() - qr = zbarlight.scan_codes('qrcode', image) - if qr is not None: - qr = qr[0].decode("utf-8") +def read_qr(path): + logger.info("opening %s", path) - return qr + with open(path, 'rb') as f: + image = Image.open(f) + image.load() + qr = zbarlight.scan_codes('qrcode', image) + if qr is not None: + qr = qr[0].decode("utf-8") + return qr diff --git a/src/resources/__init__.py b/src/resources/__init__.py index 98505c5..4aab8cd 100644 --- a/src/resources/__init__.py +++ b/src/resources/__init__.py @@ -1 +1 @@ -from resources.loader import init, get_string_resource, iso639_2_to_639_1 \ No newline at end of file +from resources.loader import init, get_string_resource, iso639_2_to_639_1 diff --git a/src/resources/loader.py b/src/resources/loader.py index cb600ef..c2534c9 100644 --- a/src/resources/loader.py +++ b/src/resources/loader.py @@ -11,65 +11,72 @@ strings_r = {} __resources_directory = None + class EventHandler(FileSystemEventHandler): - @staticmethod - def on_any_event(event): - if event.event_type == "modified" or event.event_type == "created": - logger.info("Reloading resource folder") - load_config() + @staticmethod + def on_any_event(event): + if event.event_type == "modified" or event.event_type == "created": + logger.info("Reloading resource folder") + load_config() + def install_observer(): - handler = EventHandler() - observer = Observer() - observer.schedule(handler, __resources_directory) - observer.start() + handler = EventHandler() + observer = Observer() + observer.schedule(handler, __resources_directory) + observer.start() + def _load_xml_resouce(path): - logger.info("Loading resource %s", path) + logger.info("Loading resource %s", path) + + e = ElementTree.parse(path).getroot() + lang = e.get('lang') + if lang not in strings_r: + strings_r[lang] = {} - e = ElementTree.parse(path).getroot() - lang = e.get('lang') - if lang not in strings_r: - strings_r[lang] = {} + replacements = (('{b}', ''), ('{/b}', ''), + ('{i}', ''), ('{/i}', ''), + ('{code}', ''), ('{/code}', '')) - replacements = (('{b}', ''), ('{/b}', ''), - ('{i}', ''), ('{/i}', ''), - ('{code}', ''), ('{/code}', '')) + for string in e.findall('string'): + if string.text is None: + continue - for string in e.findall('string'): - if string.text is None: - continue + value = functools.reduce(lambda s, kv: s.replace(*kv), replacements, string.text) + value = value.strip() + strings_r[lang][string.get('name')] = value + logger.debug("Loaded string resource [%s] (%s): %s", string.get('name'), lang, value) - value = functools.reduce(lambda s, kv: s.replace(*kv), replacements, string.text) - value = value.strip() - strings_r[lang][string.get('name')] = value - logger.debug("Loaded string resource [%s] (%s): %s", string.get('name'), lang, value) def load_config(): - files = glob.glob(os.path.join(__resources_directory, "strings*.xml")) - functional.apply_fn(files, _load_xml_resouce) + files = glob.glob(os.path.join(__resources_directory, "strings*.xml")) + functional.apply_fn(files, _load_xml_resouce) + def init(values_folder): - global __resources_directory - __resources_directory = values_folder + global __resources_directory + __resources_directory = values_folder + + load_config() + install_observer() - load_config() - install_observer() def iso639_2_to_639_1(lang): - # Convert ISO 639-2 to 639-1 based on available translations (i.e it -> it-IT) - return next(iter(list(filter(lambda s: s.startswith(lang), strings_r.keys()))), "en-US") + # Convert ISO 639-2 to 639-1 based on available translations (i.e it -> it-IT) + return next(iter(list(filter(lambda s: s.startswith(lang), strings_r.keys()))), "en-US") + def get_string_resource(id, lang=None): - global strings_r + global strings_r - if lang is not None and len(lang) < 5: - lang = iso639_2_to_639_1(lang) + if lang is not None and len(lang) < 5: + lang = iso639_2_to_639_1(lang) - rr = None - if lang in strings_r and id in strings_r[lang]: - rr = strings_r[lang][id] - elif id in strings_r['default']: - rr = strings_r['default'][id] + rr = None + if lang in strings_r and id in strings_r[lang]: + rr = strings_r[lang][id] + elif id in strings_r['default']: + rr = strings_r['default'][id] - return rr + return rr diff --git a/src/tests/test_db.py b/src/tests/test_db.py index fcf2bb4..461aa2d 100644 --- a/src/tests/test_db.py +++ b/src/tests/test_db.py @@ -1,35 +1,39 @@ import sys, os + sys.path.append(os.path.abspath(os.path.join('.', 'src'))) import config import database from database import TBDB + def setup_function(function): - config.init(os.path.abspath('config')) - config.get_config_prop("app")["database"] = "tmp.db" - database.init_schema(config.get_config_prop("app")["database"]) + config.init(os.path.abspath('config')) + config.get_config_prop("app")["database"] = "tmp.db" + database.init_schema(config.get_config_prop("app")["database"]) + def teardown_function(function): - os.remove(config.get_config_prop("app")["database"]) + os.remove(config.get_config_prop("app")["database"]) + def test_db(): - id = 1234 - - TBDB.create_default_chat_entry(id, 'en-US') - assert TBDB.get_chat_lang(id) == 'en-US' - assert TBDB.get_chat_active(id) == 1 - - TBDB.set_chat_lang(id, 'lang') - TBDB.set_chat_voice_enabled(id, 2) - TBDB.set_chat_photos_enabled(id, 1) - TBDB.set_chat_qr_enabled(id, 1) - TBDB.set_chat_active(id, 0) - TBDB.set_chat_ban(id, 1) - - assert TBDB.get_chat_lang(id) == 'lang' - assert TBDB.get_chat_voice_enabled(id) == 2 - assert TBDB.get_chat_photos_enabled(id) == 1 - assert TBDB.get_chat_qr_enabled(id) == 1 - assert TBDB.get_chat_active(id) == 0 - assert TBDB.get_chat_ban(id) == 1 \ No newline at end of file + id = 1234 + + TBDB.create_default_chat_entry(id, 'en-US') + assert TBDB.get_chat_lang(id) == 'en-US' + assert TBDB.get_chat_active(id) == 1 + + TBDB.set_chat_lang(id, 'lang') + TBDB.set_chat_voice_enabled(id, 2) + TBDB.set_chat_photos_enabled(id, 1) + TBDB.set_chat_qr_enabled(id, 1) + TBDB.set_chat_active(id, 0) + TBDB.set_chat_ban(id, 1) + + assert TBDB.get_chat_lang(id) == 'lang' + assert TBDB.get_chat_voice_enabled(id) == 2 + assert TBDB.get_chat_photos_enabled(id) == 1 + assert TBDB.get_chat_qr_enabled(id) == 1 + assert TBDB.get_chat_active(id) == 0 + assert TBDB.get_chat_ban(id) == 1 diff --git a/src/transcriberbot/blueprints/__init__.py b/src/transcriberbot/blueprints/__init__.py new file mode 100644 index 0000000..7fc3b7e --- /dev/null +++ b/src/transcriberbot/blueprints/__init__.py @@ -0,0 +1,5 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +from . import commands, messages, voice, photos, chat_handlers diff --git a/src/transcriberbot/blueprints/chat_handlers.py b/src/transcriberbot/blueprints/chat_handlers.py new file mode 100644 index 0000000..de64f44 --- /dev/null +++ b/src/transcriberbot/blueprints/chat_handlers.py @@ -0,0 +1,27 @@ +""" +Author: Carlo Alberto Barbano +Date: 16/02/25 +""" +import logging +import config + +from telegram import Update, ChatMember +from telegram.ext import ContextTypes + +from database import TBDB + + +async def chat_member_update(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + logging.log(config.APP_LOG, "Chat {chat_id} member update: %s", update) + + left = update.my_chat_member.new_chat_member.status in (ChatMember.LEFT, ChatMember.BANNED) + + if left: + TBDB.set_chat_active(chat_id, False) + logging.log(config.APP_LOG, f"Chat {chat_id} deactivated") + else: + chat_record = TBDB.get_chat_entry(chat_id) + if chat_record: + TBDB.set_chat_active(chat_id, 1) + logging.log(config.APP_LOG, f"Chat {chat_id} reactivated") diff --git a/src/transcriberbot/blueprints/commands.py b/src/transcriberbot/blueprints/commands.py new file mode 100644 index 0000000..297f3a6 --- /dev/null +++ b/src/transcriberbot/blueprints/commands.py @@ -0,0 +1,226 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +import logging +import asyncio +import traceback +import datetime + +from telegram import Update +from telegram.ext import ContextTypes + +import config +import resources as R +import translator +from database import TBDB + + +async def start(update: Update, context: ContextTypes.DEFAULT_TYPE): + await welcome_message(update, context) + + +async def lang(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_lang = TBDB.get_chat_lang(update.effective_chat.id) + await context.bot.send_message( + update.effective_chat.id, R.get_string_resource("language_get", chat_lang).replace("{lang}", chat_lang) + ) + + +async def rate(update: Update, context: ContextTypes.DEFAULT_TYPE): + await context.bot.send_message( + update.effective_chat.id, + R.get_string_resource("message_rate", TBDB.get_chat_lang(update.effective_chat.id)) + ) + + +async def disable_voice(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + TBDB.set_chat_voice_enabled(chat_id, 0) + await context.bot.send_message( + chat_id, R.get_string_resource("voice_disabled", TBDB.get_chat_lang(chat_id)) + ) + + +async def enable_voice(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + TBDB.set_chat_voice_enabled(chat_id, 1) + await context.bot.send_message( + chat_id, R.get_string_resource("voice_enabled", TBDB.get_chat_lang(chat_id)) + ) + + +async def disable_photos(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + TBDB.set_chat_photos_enabled(chat_id, 0) + await context.bot.send_message( + chat_id, R.get_string_resource("photos_disabled", TBDB.get_chat_lang(chat_id)) + ) + + +async def enable_photos(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + TBDB.set_chat_photos_enabled(chat_id, 1) + await context.bot.send_message( + chat_id, R.get_string_resource("photos_enabled", TBDB.get_chat_lang(chat_id)) + ) + + +async def disable_qr(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + TBDB.set_chat_qr_enabled(chat_id, 0) + await context.bot.send_message( + chat_id, R.get_string_resource("qr_disabled", TBDB.get_chat_lang(chat_id)) + ) + + +async def enable_qr(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + TBDB.set_chat_qr_enabled(chat_id, 1) + await context.bot.send_message( + chat_id, R.get_string_resource("qr_enabled", TBDB.get_chat_lang(chat_id)) + ) + + +async def translate(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + + lang = update.effective_message.text + lang = lang.replace("/translate", "").strip() + logging.debug("Language %s", lang) + + if not update.effective_message.reply_to_message: + await context.bot.send_message( + chat_id, R.get_string_resource("translate_reply_to_message", TBDB.get_chat_lang(chat_id)) + ) + return + + if not lang: + await context.bot.send_message( + chat_id, R.get_string_resource("translate_language_missing", TBDB.get_chat_lang(chat_id)) + ) + return + + if lang not in config.get_config_prop("app")["languages"]: + await context.bot.send_message( + chat_id, R.get_string_resource("translate_language_not_found", TBDB.get_chat_lang(chat_id)).format(lang) + ) + return + + lang = config.get_config_prop("app")["languages"][lang].split('-')[0] + translation = translator.translate( + source=TBDB.get_chat_lang(chat_id), + target=lang, + text=update.effective_message.reply_to_message.text + ) + + await context.bot.send_message( + chat_id, translation, reply_to_message_id=update.effective_message.reply_to_message.message_id + ) + + +async def donate(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + await context.bot.send_message( + chat_id, R.get_string_resource("message_donate", TBDB.get_chat_lang(chat_id)), parse_mode="html" + ) + + +async def privacy(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + await context.bot.send_message( + chat_id, R.get_string_resource("privacy_policy", TBDB.get_chat_lang(chat_id)), parse_mode="html" + ) + + +async def welcome_message(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_record = TBDB.get_chat_entry(update.effective_chat.id) + + language = None + if chat_record is not None: + language = chat_record["lang"] + elif update.effective_user.language_code is not None: + # Channel posts do not have a language_code attribute + logging.debug("Language_code: %s", update.effective_user.language_code) + language = update.effective_user.language_code + + message = R.get_string_resource("message_welcome", language) + message = message.replace("{languages}", + "/" + "\n/".join(config.get_language_list())) # Format them to be a list of commands + + await context.bot.send_message(update.effective_chat.id, message, "html") + + if chat_record is None: + if language is None: + language = "en-US" + + if len(language) < 5: + language = R.iso639_2_to_639_1(language) + + logging.debug( + "No record found for chat {}, creating one with lang {}".format(update.effective_chat.id, language)) + TBDB.create_default_chat_entry(update.effective_chat.id, language) + + +async def set_language(update: Update, context: ContextTypes.DEFAULT_TYPE, language): + chat_id = update.effective_chat.id + lang_ = config.get_config_prop("app")["languages"][language] # ISO 639-1 code for language + TBDB.set_chat_lang(chat_id, lang_) + message = R.get_string_resource("language_set", lang_).replace("{lang}", language) + await context.bot.send_message(chat_id, message, parse_mode="html") + + +async def users(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + tot_chats = TBDB.get_chats_num() + active_chats = TBDB.get_active_chats_num() + await context.bot.send_message( + chat_id=chat_id, text='Total users: {}\nActive users: {}'.format(tot_chats, active_chats), parse_mode='html', + ) + + +async def stats(update: Update, context: ContextTypes.DEFAULT_TYPE): + num_audios = len(context.bot_data) - 1 + num_queues = context.bot_data.get('queue_len', 0) + + audio_queue = [f"{audio_id} duration {datetime.timedelta(seconds=v['duration'])} (received {v['time']}" for + audio_id, v in + context.bot_data.items() if + audio_id != 'queue_len'] + + await context.bot.send_message( + update.effective_chat.id, f"Number of audios being currently processed: {num_audios}\n" + f"Number of audios in queue: {num_queues}\n\n" + f"{'\n'.join(audio_queue)}" + ) + + +async def broadcast(update: Update, context: ContextTypes.DEFAULT_TYPE): + chat_id = update.effective_chat.id + text = " ".join(context.args) + + async def __post(): + chats = TBDB.get_chats() + sent = 0 + + for chat in chats: + try: + await context.bot.send_message( + chat_id=chat['chat_id'], + text=text, + parse_mode='html', + ) + sent += 1 + await asyncio.sleep(0.1) + except Exception as e: + logging.error( + "Exception sending broadcast to %d: (%s)", + chat['chat_id'], e, exc_info=True + ) + + await context.bot.send_message( + chat_id=chat_id, + text='Broadcast sent to {}/{} chats'.format(sent, len(chats)), + ) + + await __post() diff --git a/src/transcriberbot/blueprints/messages.py b/src/transcriberbot/blueprints/messages.py new file mode 100644 index 0000000..e9469b0 --- /dev/null +++ b/src/transcriberbot/blueprints/messages.py @@ -0,0 +1,16 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +import resources as R + +from telegram import Update +from telegram.ext import ContextTypes +from database import TBDB + + +async def private_message(update: Update, context: ContextTypes.DEFAULT_TYPE): + await context.bot.send_message( + update.effective_chat.id, + R.get_string_resource("message_private", TBDB.get_chat_lang(update.effective_chat.id)) + ) diff --git a/src/transcriberbot/blueprints/photos.py b/src/transcriberbot/blueprints/photos.py new file mode 100644 index 0000000..b4cdf78 --- /dev/null +++ b/src/transcriberbot/blueprints/photos.py @@ -0,0 +1,75 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +import html +import logging +import os +import traceback + +import telegram +from telegram import Update +from telegram.constants import ChatType +from telegram.ext import ContextTypes + +import config +import phototools +import resources as R +from database import TBDB + + +async def photo(update: Update, context: ContextTypes.DEFAULT_TYPE): + photo_enabled = update.effective_chat.type == ChatType.PRIVATE or TBDB.get_chat_photos_enabled(update.effective_chat.id) + qr_enabled = update.effective_chat.type == ChatType.PRIVATE or TBDB.get_chat_qr_enabled(update.effective_chat.id) + + if not photo_enabled and not qr_enabled: + return + + message = update.message or update.channel_post + await process_media_photo(update, context, message.photo) + + +async def process_media_photo(update: Update, context: ContextTypes.DEFAULT_TYPE, photo): + chat_id = update.effective_chat.id + message_id = update.effective_message.id + lang = TBDB.get_chat_lang(chat_id) + + file_id = photo[-1].file_id + file_path = os.path.join(config.get_config_prop("app")["media_path"], file_id) + file: telegram.File = await context.bot.get_file(file_id) + await file.download_to_drive(file_path) + + try: + if update.effective_chat.type == ChatType.PRIVATE or TBDB.get_chat_qr_enabled(update.effective_chat.id): + qr = phototools.read_qr(file_path) + if qr is not None: + qr = R.get_string_resource("qr_result", lang) + f"\n{qr}" + + await context.bot.send_message( + chat_id=chat_id, text=qr, reply_to_message_id=message_id, + parse_mode="html" + ) + return + + if update.effective_chat.type == ChatType.PRIVATE or TBDB.get_chat_photos_enabled(update.effective_chat.id): + text = phototools.image_ocr(file_path, lang) + if text is not None: + text = R.get_string_resource("ocr_result", lang) + "\n" + html.escape(text) + await context.bot.send_message( + text=text, chat_id=chat_id, reply_to_message_id=message_id, + parse_mode="html", + ) + return + + + await context.bot.send_message( + text=R.get_string_resource("photo_no_text", lang), + chat_id=chat_id, reply_to_message_id=message_id, + parse_mode="html", + ) + + except Exception as e: + logging.error("Exception handling photo from %d", chat_id, exc_info=True) + + finally: + os.remove(file_path) diff --git a/src/transcriberbot/blueprints/voice.py b/src/transcriberbot/blueprints/voice.py new file mode 100644 index 0000000..7c5bf90 --- /dev/null +++ b/src/transcriberbot/blueprints/voice.py @@ -0,0 +1,228 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +import asyncio +import logging +import os +import traceback +import datetime +from asyncio import CancelledError + +import telegram +from telegram import Update, Voice, InlineKeyboardMarkup, InlineKeyboardButton, VideoNote, Document +from telegram.constants import ChatType +from telegram.ext import ContextTypes + +import audiotools +import config +import resources as R +from database import TBDB + +logger = logging.getLogger(__name__) + + +# TODO: check if cpu usage is too high, if so, use ProcessPoolExecutor + + +async def voice_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + if TBDB.get_chat_voice_enabled(update.effective_chat.id) == 0: + return + + await run_voice_task(update, context, update.effective_message.voice, "voice") + + +async def audio_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + if TBDB.get_chat_voice_enabled(update.effective_chat.id) == 0: + return + + await run_voice_task(update, context, update.effective_message.audio, "audio") + + +async def video_note_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + if TBDB.get_chat_voice_enabled(update.effective_chat.id) == 0: + return + + await run_voice_task(update, context, update.effective_message.video_note, "video_note") + + +async def document_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + if TBDB.get_chat_voice_enabled(update.effective_chat.id) == 0: + return + + await run_voice_task(update, context, update.effective_message.document, "document") + + +async def stop_task(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + task_id = int(update.callback_query.data) + task: asyncio.Task = context.bot_data.get(task_id)["task"] + + if task is not None: + task.cancel() + context.bot_data.pop(task_id) + else: + logging.warning("Task not found") + + +async def wait_for_task_queue(context: ContextTypes.DEFAULT_TYPE): + # wait until there are less than N tasks in bot_data + context.bot_data['queue_len'] = context.bot_data.get('queue_len', 0) + 1 + + while len(context.bot_data) >= config.get_config_prop("app")["voice_max_threads"] + 1: + logging.debug("Waiting for tasks to finish") + await asyncio.sleep(1) + + context.bot_data['queue_len'] -= 1 + logging.debug("Task queue has available space") + + +async def run_voice_task(update: Update, context: ContextTypes.DEFAULT_TYPE, media: Voice, + name): + await wait_for_task_queue(context) + + try: + task = asyncio.create_task(process_media_voice(update, context, media, name)) + context.bot_data[update.effective_message.message_id] = { + 'task': task, + 'duration': media.duration, + 'time': datetime.datetime.now(datetime.timezone.utc) + } + await asyncio.gather(task) + finally: + context.bot_data.pop(update.effective_message.message_id) + + +async def process_media_voice(update: Update, context: ContextTypes.DEFAULT_TYPE, media: [Voice | VideoNote | Document], + name: str) -> None: + chat_id = update.effective_chat.id + file_size = media.file_size + max_size = config.get_config_prop("app").get("max_media_voice_file_size", 20 * 1024 * 1024) + + if file_size > max_size: + error_message = R.get_string_resource("file_too_big", TBDB.get_chat_lang(chat_id)).format( + max_size / (1024 * 1024)) + "\n" + await context.bot.send_message( + chat_id, error_message, parse_mode="html", reply_to_message_id=update.effective_message.message_id + ) + return + + file_id = media.file_id + file_path = os.path.join(config.get_config_prop("app")["media_path"], file_id) + file: telegram.File = await context.bot.get_file(file_id) + await file.download_to_drive(file_path) + + try: + await transcribe_audio_file(update, context, file_path) + except Exception: + logger.error("Exception handling %s from %d", name, chat_id, exc_info=True) + finally: + os.remove(file_path) + + +async def transcribe_audio_file(update: Update, context: ContextTypes.DEFAULT_TYPE, path: str): + chat_id = update.effective_chat.id + task_id = update.effective_message.message_id + lang = TBDB.get_chat_lang(chat_id) + is_group = update.effective_chat.type != ChatType.PRIVATE + + api_key = config.get_config_prop("wit").get(lang, None) + if api_key is None: + logger.error("Language not found in wit.json %s", lang) + await context.bot.send_message( + chat_id, R.get_string_resource("unknown_api_key", lang).format(language=lang), parse_mode="html", + reply_to_message_id=update.effective_message.message_id + ) + return + + logger.debug("Using key %s for lang %s", api_key, lang) + + message = await context.bot.send_message( + chat_id, R.get_string_resource("transcribing", lang), parse_mode="html", + reply_to_message_id=update.effective_message.message_id + ) + + logger.debug("Starting task %d", task_id) + keyboard = InlineKeyboardMarkup( + [[InlineKeyboardButton("Stop", callback_data=task_id)]] + ) + + text = "" + if is_group: + text = R.get_string_resource("transcription_text", lang) + "\n" + + try: + async for idx, speech, n_chunks in audiotools.transcribe(path, api_key): + logging.debug(f"Transcription idx={idx} n_chunks={n_chunks}, text={speech}") + suffix = f" [{idx + 1}/{n_chunks}]" if idx < n_chunks - 1 else "" + reply_markup = keyboard if idx < n_chunks - 1 else None + + if len(text + " " + speech) >= 4000: + text = R.get_string_resource("transcription_continues", lang) + "\n" + message = await context.bot.send_message( + chat_id, f"{text} {speech} {suffix}", + reply_to_message_id=message.message_id, parse_mode="html", + reply_markup=reply_markup + ) + else: + message = await context.bot.edit_message_text( + f"{text} {speech} {suffix}", chat_id=chat_id, + message_id=message.message_id, parse_mode="html", + reply_markup=reply_markup + ) + + text = f"{text} {speech}" + + # retry_num = 0 + # retry = True + # while retry: # Retry loop + # try: + # if len(text + " " + speech) >= 4000: + # text = R.get_string_resource("transcription_continues", lang) + "\n" + # message = await context.bot.send_message( + # chat_id, f"{text} {speech} {suffix}", + # reply_to_message_id=message.message_id, parse_mode="html", + # reply_markup=keyboard + # ) + # else: + # message = await context.bot.edit_message_text( + # f"{text} {speech} {suffix}", chat_id=chat_id, + # message_id=message.message_id, parse_mode="html", + # reply_markup=keyboard + # ) + # + # text += " " + speech + # retry = False + # + # except telegram.error.TimedOut as e: + # print(e) + # logger.error("Timeout error %s", traceback.format_exc()) + # retry_num += 1 + # if retry_num >= 3: + # retry = False + # + # except telegram.error.RetryAfter as r: + # logger.warning("Retrying after %d", r.retry_after) + # await asyncio.sleep(r.retry_after) + # + # except telegram.error.TelegramError: + # logger.error("Telegram error %s", traceback.format_exc()) + # retry = False + + + except CancelledError: + logging.debug("Task cancelled") + await context.bot.edit_message_text( + message.text + " " + R.get_string_resource("transcription_stopped", lang), chat_id=chat_id, + message_id=message.message_id, parse_mode="html" + ) + return + + except Exception as e: + logger.error("Could not transcribe audio") + + await context.bot.edit_message_text( + R.get_string_resource("transcription_failed", lang), chat_id=chat_id, + message_id=message.message_id, parse_mode="html" + ) + + raise e diff --git a/src/transcriberbot/bot.py b/src/transcriberbot/bot.py index 71f5efa..06c170c 100644 --- a/src/transcriberbot/bot.py +++ b/src/transcriberbot/bot.py @@ -1,296 +1,76 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +from telegram import Update + import config -import database -import resources as R -import metaclass -import functional -import audiotools -import phototools -import pprint import logging -import os -import traceback -import telegram -import time -import antiflood -import translator - -from datetime import datetime - -from database import TBDB -from telegram.ext import Updater, CommandHandler, MessageHandler, CallbackQueryHandler, Filters -from telegram.ext import messagequeue as mq -from telegram.utils.request import Request -from telegram.error import Unauthorized - -from concurrent.futures import ThreadPoolExecutor - -from transcriberbot import tbfilters -from transcriberbot.channel_command_handler import ChannelCommandHandler - -logger = logging.getLogger(__name__) - -# Utils -def get_language_list(): - return config.get_config_prop("app")["languages"].keys() - -def get_chat_id(update): - chat_id = None - if update.message is not None: - chat_id = update.message.chat.id - elif update.channel_post is not None: - chat_id = update.channel_post.chat.id - return chat_id - -def get_message_id(update): - if update.message is not None: - return update.message.message_id - elif update.channel_post is not None: - return update.channel_post.message_id - - return None - -class TranscriberBot(metaclass=metaclass.Singleton): - class MQBot(telegram.bot.Bot): - def __init__(self, *args, is_queued_def=True, mqueue=None, **kwargs): - super().__init__(*args, **kwargs) - self._is_messages_queued_default = is_queued_def - self._msg_queue = mqueue or mq.MessageQueue() - - chats = TBDB.get_chats() - self.active_chats_cache = dict(zip( - [c['chat_id'] for c in chats], - [c['active'] for c in chats] - )) - - def __del__(self): - try: - self._msg_queue.stop() - except: - pass - super().__del__() - - def active_check(self, fn, *args, **kwargs): - err = None - res = None - - try: - res = fn(*args, **kwargs) - except Unauthorized as e: - pprint.pprint(e) - logger.error(e) - err = e - - if err is not None: - chat_id = kwargs['chat_id'] - if chat_id not in self.active_chats_cache or self.active_chats_cache[chat_id] == 1: - logger.debug("Marking chat {} as inactive".format(chat_id)) - self.active_chats_cache[chat_id] = 0 - TBDB.set_chat_active(chat_id, self.active_chats_cache[chat_id]) - raise err - - return res - - @mq.queuedmessage - def send_message(self, *args, **kwargs): - return self.active_check(super().send_message, *args, **kwargs) - - @mq.queuedmessage - def edit_message_text(self, *args, **kwargs): - return self.active_check(super().edit_message_text, *args, **kwargs) - - def __init__(self): - self.error_handler = None - self.message_handlers = {} - self.command_handlers = {} - self.callback_handlers = {} - self.floods = {} - self.workers = {} - - antiflood.register_flood_warning_callback( - lambda chat_id: TranscriberBot.get().bot().send_message( - chat_id=chat_id, - text=R.get_string_resource("flood_warning", TBDB.get_chat_lang(chat_id)), - parse_mode = "html", - is_group = chat_id < 0 - ) - ) - - def flood_started(chat_id): - logger.info("Flood detected in %d, ignoring messages", chat_id) - self.floods[chat_id] = True - - def flood_ended(chat_id): - logger.info("Flood ended for %d", chat_id) - self.floods[chat_id] = False - - antiflood.register_flood_started_callback(flood_started) - antiflood.register_flood_ended_callback(flood_ended) - - @staticmethod - def get(): - return TranscriberBot() - - def bot(self): - return self.mqbot - - def start(self, token): - self.voice_thread_pool = ThreadPoolExecutor( - max_workers=config.get_config_prop("app")["voice_max_threads"] - ) - self.photos_thread_pool = ThreadPoolExecutor( - max_workers=config.get_config_prop("app")["photos_max_threads"] - ) - - self.misc_thread_pool = ThreadPoolExecutor( - max_workers=2 - ) - - self.queue = mq.MessageQueue() - self.request = Request(con_pool_size=10) - self.mqbot = self.MQBot(token, request=self.request, mqueue=self.queue) - self.updater = Updater(bot=self.mqbot, use_context=True) - self.dispatcher = self.updater.dispatcher - self.__register_handlers() - self.updater.start_polling(clean=True) - self.updater.idle() - - def register_message_handler(self, filter, fn): - self.message_handlers[filter] = fn - - def register_command_handler(self, fn, filters=None): - self.command_handlers[fn.__name__] = (fn, filters) - - def register_callback_handler(self, fn): - self.callback_handlers[fn.__name__] = fn - - def start_thread(self, id): - self.workers[str(id)] = True - - def stop_thread(self, id): - self.workers[str(id)] = False - - def thread_running(self, id): - return self.workers[str(id)] - - def del_thread(self, id): - del self.workers[str(id)] - - def __add_handler(self, handler): - self.dispatcher.add_handler(handler) - - def __add_error_handler(self, handler): - self.dispatcher.add_error_handler(handler) - - def __pre__hook(self, fn, u, c, **kwargs): - b = c.bot - - m = u.message or u.channel_post - if not m: - return - - age = (datetime.utcnow() - m.date.replace(tzinfo=None)).total_seconds() - if age > config.get_config_prop("app")["antiflood"]["age_threshold"]: - return - - chat_id = get_chat_id(u) - antiflood.on_chat_msg_received(chat_id) - - if chat_id in self.floods and self.floods[chat_id] is True: - return - - if not TBDB.get_chat_entry(chat_id): - # happens when welcome/joined message is not received - TBDB.create_default_chat_entry(chat_id, 'en-US') - - if chat_id in self.mqbot.active_chats_cache and self.mqbot.active_chats_cache[chat_id] == 0: - logger.debug("Marking chat {} as active".format(chat_id)) - self.mqbot.active_chats_cache[chat_id] = 1 - TBDB.set_chat_active(chat_id, self.mqbot.active_chats_cache[chat_id]) - - return fn(b, u, **kwargs) - - def __register_handlers(self): - functional.apply_fn( - self.message_handlers.items(), - lambda h: self.__add_handler(MessageHandler( - h[0], - lambda b, u, **kwargs: self.__pre__hook(h[1], b, u, **kwargs))) - ) - - functional.apply_fn( - self.command_handlers.items(), - lambda h: self.__add_handler(ChannelCommandHandler( - h[0], - lambda b, u, **kwargs: self.__pre__hook(h[1][0], b, u, **kwargs), - filters=h[1][1])) - ) - - functional.apply_fn( - self.callback_handlers.items(), - lambda h: self.__add_handler(CallbackQueryHandler(h[1])) - ) - -# Decorators for adding callbacks -def message(filter): - def decor(fn): - TranscriberBot.get().register_message_handler(filter, fn) - return decor - -def command(filters=None): - def decor(fn): - TranscriberBot.get().register_command_handler(fn, filters) - return decor - -def callback_query(fn): - TranscriberBot.get().register_callback_handler(fn) - -# Install language command callbacks -def language_handler(bot, update, language): - chat_id = get_chat_id(update) - lang = config.get_config_prop("app")["languages"][language] #ISO 639-1 code for language - TBDB.set_chat_lang(chat_id, lang) - message = R.get_string_resource("language_set", lang).replace("{lang}", language) - reply = update.message or update.channel_post - reply.reply_text(message) - -def install_language_handlers(language): - handler = lambda b, u: language_handler(b, u, language) - handler.__name__ = language - TranscriberBot.get().register_command_handler(handler, filters=tbfilters.chat_admin) - -# Init -def init(): - functional.apply_fn(get_language_list(), install_language_handlers) - -def welcome_message(bot, update): - chat_id = get_chat_id(update) - message_id = get_message_id(update) - chat_record = TBDB.get_chat_entry(chat_id) - - language = None - if chat_record is not None: - language = chat_record["lang"] - elif update.message is not None and update.message.from_user.language_code is not None: - # Channel posts do not have a language_code attribute - logger.debug("Language_code: %s", update.message.from_user.language_code) - language = update.message.from_user.language_code - - message = R.get_string_resource("message_welcome", language) - message = message.replace("{languages}", "/" + "\n/".join(get_language_list())) #Format them to be a list of commands - bot.send_message( - chat_id=chat_id, - text=message, - reply_to_message_id = message_id, - parse_mode = "html", - is_group = chat_id < 0 - ) - - if chat_record is None: - if language is None: - language = "en-US" - if len(language) < 5: - language = R.iso639_2_to_639_1(language) - - logger.debug("No record found for chat {}, creating one with lang {}".format(chat_id, language)) - TBDB.create_default_chat_entry(chat_id, language) \ No newline at end of file +from telegram.ext import MessageHandler, ApplicationBuilder, CommandHandler, ContextTypes, CallbackQueryHandler, \ + ChatMemberHandler +from functools import partial +from transcriberbot.blueprints import commands, messages, voice, photos, chat_handlers +from transcriberbot.blueprints.commands import set_language + +from telegram.ext.filters import VOICE, VIDEO_NOTE, AUDIO, PHOTO +from transcriberbot.filters import chat_admin, FromPrivate, AllowedDocument, BotAdmin + + +def run(bot_token: str): + application = (ApplicationBuilder() + .token(bot_token) + .concurrent_updates(True) + .build()) + + logging.log(config.APP_LOG, "Installing handlers") + application.add_handler(CallbackQueryHandler(voice.stop_task)) + + application.add_handler(ChatMemberHandler( + chat_handlers.chat_member_update, + chat_member_types=ChatMemberHandler.MY_CHAT_MEMBER + )) + + chat_admin_handlers = { + 'start': commands.start, + 'help': commands.start, + 'lang': commands.lang, + 'rate': commands.rate, + 'disable_voice': commands.disable_voice, + 'enable_voice': commands.enable_voice, + 'disable_photos': commands.disable_photos, + 'enable_photos': commands.enable_photos, + 'disable_qr': commands.disable_qr, + 'enable_qr': commands.enable_qr, + 'translate': commands.translate, + 'donate': commands.donate, + 'privacy': commands.privacy + } + + for command, callback in chat_admin_handlers.items(): + application.add_handler(CommandHandler(command, lambda u, c, cb=callback: chat_admin(u, c, cb))) + + logging.log(config.APP_LOG, "Installing language handlers..") + for language in config.get_language_list(): + callback = partial(set_language, language=language) + application.add_handler( + CommandHandler(language, lambda u, c, cb=callback: chat_admin(u, c, cb)) + ) + + logging.log(config.APP_LOG, "Installing admin controls") + application.add_handler(CommandHandler("users", commands.users, filters=BotAdmin())) + application.add_handler(CommandHandler("broadcast", commands.broadcast, filters=BotAdmin())) + application.add_handler(CommandHandler("stats", commands.stats, filters=BotAdmin())) + + logging.log(config.APP_LOG, "Installing message handlers") + application.add_handler(MessageHandler(VOICE, voice.voice_message)) + application.add_handler(MessageHandler(AUDIO, voice.audio_message)) + application.add_handler(MessageHandler(VIDEO_NOTE, voice.video_note_message)) + application.add_handler(MessageHandler(AllowedDocument(config.get_document_extensions()), voice.document_message)) + + application.add_handler(MessageHandler(PHOTO, photos.photo)) + + application.add_handler(MessageHandler(FromPrivate(), messages.private_message)) + + logging.log(config.APP_LOG, "Starting bot..") + application.run_polling(allowed_updates=Update.ALL_TYPES) diff --git a/src/transcriberbot/filters/__init__.py b/src/transcriberbot/filters/__init__.py new file mode 100644 index 0000000..4d5e935 --- /dev/null +++ b/src/transcriberbot/filters/__init__.py @@ -0,0 +1,5 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +from .filters import * diff --git a/src/transcriberbot/filters/filters.py b/src/transcriberbot/filters/filters.py new file mode 100644 index 0000000..bc89d30 --- /dev/null +++ b/src/transcriberbot/filters/filters.py @@ -0,0 +1,101 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +import logging +import asyncio + +from telegram.constants import ChatType +from telegram.ext import ContextTypes +from telegram.ext.filters import UpdateFilter +from telegram import Update, ChatMember + +import config + + +class AllowedDocument(UpdateFilter): + """ + Checks if the message has document media with allowed extensions. + """ + + def __init__(self, allowed_exts) -> None: + super().__init__() + self.allowed_exts = allowed_exts + if len(allowed_exts) == 0: + logging.warning("No allowed extensions were provided. Documents will be disabled") + + def filter(self, update: Update) -> bool: + if update.effective_message.animation: + return False + + if update.effective_message.document: + logging.debug("Received document %s", update.effective_message.document.file_id) + filename = update.effective_message.document.file_name + if '.' not in filename: # No extension + return False + ext = filename.split('.')[-1] + return ext in self.allowed_exts + return False + + +class FromPrivate(UpdateFilter): + """ + Checks if the message was sent in a private conversation. + """ + + def filter(self, update: Update) -> bool: + return update.effective_chat.type == ChatType.PRIVATE + + +class ChatAdmin(UpdateFilter): + """ + Checks if the message was sent by a chat admin. + """ + + def filter(self, update: Update) -> bool: + if update.effective_chat.type in (ChatType.PRIVATE, ChatType.CHANNEL): + return True + + user = update.effective_user + chat_admins: list[ChatMember] = asyncio.get_event_loop().run_until_complete( + update.effective_chat.get_administrators()) + + is_admin = list(filter(lambda admin: admin.user.id == user.id, chat_admins)) + is_admin = len(is_admin) > 0 + + return is_admin + + +async def chat_admin(update: Update, context: ContextTypes.DEFAULT_TYPE, callback): + + if update.effective_chat.type in (ChatType.PRIVATE, ChatType.CHANNEL): + is_admin = True + else: + user = update.effective_user + + if user.id == 1087968824: # Anonymous admin + is_admin = True + + else: + chat_admins: list[ChatMember] = await update.effective_chat.get_administrators() + + is_admin = list(filter(lambda admin: admin.user.id == user.id, chat_admins)) + is_admin = len(is_admin) > 0 + + if is_admin: + return await callback(update, context) + + +class BotAdmin(UpdateFilter): + """ + Checks if the message was sent by the bot admin. + """ + + def filter(self, update: Update) -> bool: + user = update.effective_user + bot_admins = config.get_bot_admins() + + is_admin = list(filter(lambda admin_id: admin_id == user.id, bot_admins)) + is_admin = len(is_admin) > 0 + + return is_admin diff --git a/src/transcriberbot/multiprocessing/__init__.py b/src/transcriberbot/multiprocessing/__init__.py new file mode 100644 index 0000000..616f1a1 --- /dev/null +++ b/src/transcriberbot/multiprocessing/__init__.py @@ -0,0 +1,5 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +from .pools import voice_pool, init diff --git a/src/transcriberbot/multiprocessing/pools.py b/src/transcriberbot/multiprocessing/pools.py new file mode 100644 index 0000000..0af1db9 --- /dev/null +++ b/src/transcriberbot/multiprocessing/pools.py @@ -0,0 +1,34 @@ +""" +Author: Carlo Alberto Barbano +Date: 15/02/25 +""" +import logging +import config +from concurrent.futures import ThreadPoolExecutor + +voice_thread_pool, photos_thread_pool, misc_thread_pool = None, None, None + + +def init(): + global voice_thread_pool + global photos_thread_pool + global misc_thread_pool + + voice_thread_pool = ThreadPoolExecutor( + max_workers=config.get_config_prop("app")["voice_max_threads"] + ) + photos_thread_pool = ThreadPoolExecutor( + max_workers=config.get_config_prop("app")["photos_max_threads"] + ) + + misc_thread_pool = ThreadPoolExecutor( + max_workers=2 + ) + + print("POOLS INITIALIZED:", voice_thread_pool, photos_thread_pool, misc_thread_pool) + logging.info("Thread pools initialized") + + +def voice_pool(): + global voice_thread_pool + return voice_thread_pool diff --git a/src/translator/__init__.py b/src/translator/__init__.py index 114bcb2..b8e58e5 100644 --- a/src/translator/__init__.py +++ b/src/translator/__init__.py @@ -1 +1 @@ -from translator.translator import detect_language, translate \ No newline at end of file +from translator.translator import detect_language, translate diff --git a/src/translator/translator.py b/src/translator/translator.py index 305e46b..1a75d4f 100644 --- a/src/translator/translator.py +++ b/src/translator/translator.py @@ -6,38 +6,39 @@ def detect_language(text): - global yandex_detect_url + global yandex_detect_url - r = requests.post( - yandex_detect_url.format(config.get_config_prop("yandex")["translate_key"]), - data={'text': text} - ) - res = r.json() + r = requests.post( + yandex_detect_url.format(config.get_config_prop("yandex")["translate_key"]), + data={'text': text} + ) + res = r.json() - if 'lang' in res: - return res['lang'] - else: - return None + if 'lang' in res: + return res['lang'] + else: + return None def translate(source, target, text): - global yandex_translate_url - - autodetect = detect_language(text) - - if autodetect is not None: - source = autodetect - print("Autodetected language: {0}".format(autodetect)) - - lang = source + "-" + target - print(lang) - - r = requests.post( - yandex_translate_url.format(config.get_config_prop("yandex")["translate_key"]), - data={'lang': lang, 'text': text} - ) - - print(r) - res = r.json() - print(res) - return str(res['text'][0]) + "\n\nPowered by Yandex.Translate http://translate.yandex.com" + return "Translation service currently unavailable." + global yandex_translate_url + + autodetect = detect_language(text) + + if autodetect is not None: + source = autodetect + print("Autodetected language: {0}".format(autodetect)) + + lang = source + "-" + target + print(lang) + + r = requests.post( + yandex_translate_url.format(config.get_config_prop("yandex")["translate_key"]), + data={'lang': lang, 'text': text} + ) + + print(r) + res = r.json() + print(res) + return str(res['text'][0]) + "\n\nPowered by Yandex.Translate http://translate.yandex.com" diff --git a/values/strings.xml b/values/strings.xml index a453e7f..cba0064 100644 --- a/values/strings.xml +++ b/values/strings.xml @@ -2,6 +2,7 @@ Language set to {lang} +Current language: {lang} This bot transcribes audio and pictures into text. Add it to a group or forward audio messages and pictures to it. @@ -60,6 +61,7 @@ LTC: {code}LdsVPxqHR6PuKeMNvGYmMEkBRQ7M3AP3uY{/code} {b}WARNING:{/b} Flood detected. Stop spamming please Unknown language: {language} +Please specify a language to translate to, e.g. "/translate english" You must reply to a message in order to translate it diff --git a/values/strings_de-DE.xml b/values/strings_de-DE.xml index d33970f..c869218 100644 --- a/values/strings_de-DE.xml +++ b/values/strings_de-DE.xml @@ -1,6 +1,6 @@ - + {lang} als Sprache festgelegt