127 changes: 84 additions & 43 deletions app/nitmre_nlp_utils/preprocess.py
@@ -1,5 +1,6 @@
import re
import pandas as pd
import numpy as np
from tqdm import tqdm


@@ -13,14 +14,16 @@ def _tokenize(msg: str) -> set[str]:
/ often used between acronyms in an explicit/implied coordination
line (e.g., DO/LL)
"""
p = re.compile('[^\w/#@]+')
return set(p.split(msg)) # only keep unique tokens to avoid revisiting during preprocessing
p = re.compile(r'[^\w/#@]+')
return set(
p.split(msg)
) # only keep unique tokens to avoid revisiting during preprocessing
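
A quick sketch (not part of the diff) of what the corrected raw-string pattern keeps; the sample message is illustrative only:

import re

# Split on runs of anything that is not a word character, '/', '#' or '@',
# so tokens like DO/LL, #taskers and @mentions survive intact.
p = re.compile(r'[^\w/#@]+')
print(set(p.split('DO/LL please review #tasker with @jsmith ASAP')))
# {'DO/LL', 'please', 'review', '#tasker', 'with', '@jsmith', 'ASAP'}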


def _acronym_repl_helper(m: re.Match, token_expanded: str) -> str:
def _acronym_repl_helper(m: re.Match[str], token_expanded: str) -> str:
"""Helper function that allows us to preserve plurality when expanding
acronyms.

Can be extended with compiled regex to do additional processing on
acronyms."""
result = f' {token_expanded}'
@@ -32,7 +35,7 @@ def preprocess_message(
icao_dictionary: dict[str, str],
msg: str,
*,
msg_only: bool=False,
msg_only: bool = False,
) -> str | tuple[str, list[str], list[str]]:
"""Extract all tokens from a message and replace instances with expanded
acronyms.
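
For reference, a usage sketch of the keyword-only flag. Both dictionaries and the message are hypothetical, and the exact call-sign/ICAO extraction is collapsed in this diff:

acronyms = {'MSN': 'mission'}       # hypothetical acronym dictionary
icao_names = {'KDOV': 'Dover AFB'}  # hypothetical ICAO dictionary

# Default: the expanded message plus any call signs and ICAOs found.
expanded, call_signs, icaos = preprocess_message(
    acronyms, icao_names, 'RCH123 departs KDOV on a new MSN'
)

# msg_only=True: just the expanded string.
expanded_only = preprocess_message(
    acronyms, icao_names, 'RCH123 departs KDOV on a new MSN', msg_only=True
)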
@@ -43,7 +46,8 @@
# Any acronyms found that match a token will be expanded;
# matches are case-sensitive.
msg_expanded = msg
call_signs, icaos = [], []
call_signs: list[str] = []
icaos: list[str] = []
for token in tokens:
# compile the token as a regex pattern
p = re.compile(r'{}'.format(token))
@@ -62,8 +66,8 @@
token_expanded = acronym_dictionary[token]
p = re.compile(r'\s' + r'{}'.format(token) + r's?')
msg_expanded = p.sub(
lambda m: _acronym_repl_helper(m, token_expanded),
msg_expanded)
lambda m: _acronym_repl_helper(m, token_expanded), msg_expanded
)

# RCH call signs
else:
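
The plurality-preserving substitution is easiest to see standalone. In the sketch below the dictionary entry and the 's'-restoring logic are assumptions, since the helper's full body is collapsed in this diff:

import re

def expand(m: re.Match[str], expansion: str) -> str:
    # Assumed behaviour: re-attach the trailing 's' when the matched
    # acronym was plural, as _acronym_repl_helper preserves plurality.
    suffix = 's' if m.group(0).endswith('s') else ''
    return f' {expansion}{suffix}'

token, expansion = 'MSN', 'mission'  # hypothetical dictionary entry
p = re.compile(r'\s' + token + r's?')
print(p.sub(lambda m: expand(m, expansion),
            'Prioritize these MSNs over the next MSN brief'))
# 'Prioritize these missions over the next mission brief'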
@@ -84,28 +88,60 @@
return msg_expanded if msg_only else (msg_expanded, call_signs, icaos)


def _thread_message(
root_message: str,
df_thread: pd.DataFrame,
message_col_name: str,
) -> str:
"""Take a root message and a dataframe of messages in the root's thread to
create a single threaded message.
def _sep_roots_and_threads(
df: pd.DataFrame,
) -> tuple[pd.DataFrame, pd.DataFrame]:
df_roots = df[(df['root_id'] == '') | (df['root_id'].isna())].copy(
deep=True
)
df_threads = df[(df['root_id'] != '') & ~(df['root_id'].isna())].copy(
deep=True
)

# To make sure the thread is in chronological order.
df_threads.sort_values(by='create_at', ascending=True, inplace=True) # type: ignore

return df_roots, df_threads


Use message_col_name to specify the column that contains the messages in
the thread."""
messages = [root_message]
def _fix_missing_roots(df: pd.DataFrame) -> pd.DataFrame:
"""Fix the dataframe so that missing root_id messages don't cause their
threads to be discarded.

for _, row in df_thread.iterrows():
message = row[message_col_name]
messages.append(message)
We might lose the root message for various reasons (e.g., weak learner,
root message dated outside of the dataset's date range), but we don't want
to lose the rest of the thread.

return '\n'.join(messages)
This fix takes the earliest message in each thread (by create_at value) and
sets that message's id as the new root_id for all messages in the thread.
"""
df_roots, df_threads = _sep_roots_and_threads(df)

# Identify the ids of the root messages that are not in the dataframe
root_ids = set(df_roots['id'])
thread_root_ids = set(df_threads['root_id'])
missing_root_ids = thread_root_ids - root_ids
df_missing = df_threads[df_threads['root_id'].isin(missing_root_ids)] # type: ignore

# Take the earliest message in each thread and make that message the root
idx_min = df_missing.groupby('root_id')['create_at'].transform('idxmin') # type: ignore
new_root_ids = df_missing.loc[idx_min, 'id'].values # type: ignore
df_missing.loc[:, 'root_id'] = new_root_ids
df_missing.loc[:, 'root_id'] = np.where(
df_missing['id'] == df_missing['root_id'], '', df_missing['root_id']
)

# Update the original dataframe
df.update(df_missing) # type: ignore

return df
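
A small sketch of the repair, assuming a thread whose root never made it into the dataset (ids and timestamps are made up):

import pandas as pd

# Three replies all pointing at a root 'x' that is absent from the frame.
df = pd.DataFrame({
    'id': ['m1', 'm2', 'm3'],
    'root_id': ['x', 'x', 'x'],
    'create_at': pd.to_datetime(['2024-03-02', '2024-03-01', '2024-03-03']),
    'message': ['second', 'first', 'third'],
})
fixed = _fix_missing_roots(df)
# 'm2' is the earliest message, so it is promoted to root (root_id '')
# and 'm1'/'m3' are re-pointed at it.
print(fixed[['id', 'root_id']].to_dict('records'))
# [{'id': 'm1', 'root_id': 'm2'}, {'id': 'm2', 'root_id': ''},
#  {'id': 'm3', 'root_id': 'm2'}]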


def convert_conversation_threads(
df: pd.DataFrame,
message_col_name: str,
id_col_name: str = 'id',
) -> pd.DataFrame:
"""Take a full set of individual messages to form single message threads.

@@ -118,28 +154,33 @@ def convert_conversation_threads(
- create_at
- the given message_col_name
"""
req_cols_set = {'id', 'root_id', 'create_at', message_col_name}
req_cols_set = {id_col_name, 'root_id', 'create_at', message_col_name}
if len(req_cols_set & set(df.columns)) != len(req_cols_set):
raise KeyError(f'Invalid dataframe. Required columns: {req_cols_set}.')

print('Converting raw messages to conversation threads.')

# Separate the root message from messages in the thread.
df_roots, df_threads = df[df['root_id'] == ''], df[df['root_id'] != '']

for root_id in tqdm(df_roots['id'].to_list()):
df_root = df_roots[df_roots['id'] == root_id]
df_thread = df_threads[df_threads['root_id'] == root_id].copy()

# Preserve chronological ordering of each message in the thread.
df_thread.sort_values(by='create_at', ascending=True, inplace=True)

threaded_message = _thread_message(
df_root.iloc[0][message_col_name], df_thread, message_col_name)

# Replace the original message with the thread.
df_roots.loc[
df_roots['id'] == root_id, message_col_name
] = threaded_message

return df_roots

df = _fix_missing_roots(df)
df_roots, df_threads = _sep_roots_and_threads(df)

# Grouping and mapping messages outside of the loop is more efficient than
# filtering for each root id in every iteration.
grouped_threads = df_threads.groupby('root_id')[message_col_name].apply( # type: ignore
list
)
root_messages: dict[str, str] = df_roots.set_index(id_col_name)[ # type: ignore
message_col_name
].to_dict()

for root_id in tqdm(df_roots[id_col_name].to_list()):
# Messages without a thread are untouched.
if root_id in grouped_threads:
root_message = root_messages[root_id]
thread_messages = grouped_threads[root_id]
all_messages = [root_message] + thread_messages
threaded = '\n'.join(all_messages)

# Replace the original message with the thread.
df_roots.loc[df_roots[id_col_name] == root_id, message_col_name] = threaded

return df_roots # rows are in their original order
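
Putting it together, an end-to-end sketch with one threaded root and one standalone post; the frame contents are illustrative, and the import path follows the tests below:

import pandas as pd

from app.nitmre_nlp_utils import preprocess as pre

df = pd.DataFrame({
    'id': ['r1', 't1', 't2', 'solo'],
    'root_id': ['', 'r1', 'r1', ''],
    'create_at': pd.to_datetime(
        ['2024-05-01', '2024-05-03', '2024-05-02', '2024-05-04']
    ),
    'message': ['kickoff', 'second reply', 'first reply', 'unthreaded'],
})
out = pre.convert_conversation_threads(df, 'message')
print(out['message'].tolist())
# ['kickoff\nfirst reply\nsecond reply', 'unthreaded']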
24 changes: 24 additions & 0 deletions app/ppg_common/services/mattermost_utils.py
@@ -274,3 +274,27 @@ def get_all_team_posts_by_substring(mm_base_url, mm_token, team_id, search_str):
logger.error(f"{resp.url} request failed: {resp.status_code}")

return ddf

def get_post(mm_base_url, mm_token, post_id, get_thread=False):
"""get a single post"""

rdf = pd.DataFrame()
ustr = f'{mm_base_url}/api/v4/posts/{post_id}'
if get_thread:
ustr = f'{ustr}/thread'
resp = requests.get(ustr,
headers={'Authorization': f'Bearer {mm_token}'},
timeout=HTTP_REQUEST_TIMEOUT_S)
if resp.status_code < 400:
pdata = resp.json()
# print(pdata)
if get_thread:
pdata = pdata['posts']
rdf = pd.DataFrame(pdata).transpose()
# display(rdf)
else:
rdf = pd.DataFrame([pdata])
else:
logger.error(f"{resp.url} request failed: {resp.status_code}")

return rdf
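
Usage sketch; the base URL, token and post id below are placeholders:

MM_BASE_URL = 'https://mattermost.example.com'  # placeholder
MM_TOKEN = 'xxxx-token-xxxx'                    # placeholder

# Single post: a one-row DataFrame.
post_df = get_post(MM_BASE_URL, MM_TOKEN, 'abc123postid')

# Whole thread: one row per post, because the /thread response's
# 'posts' mapping (keyed by post id) is transposed into rows.
thread_df = get_post(MM_BASE_URL, MM_TOKEN, 'abc123postid', get_thread=True)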
91 changes: 91 additions & 0 deletions tests/nitmre_nlp_utils/test_preprocess.py
@@ -0,0 +1,91 @@
import pytest
import random
import pandas as pd
import numpy as np

from datetime import datetime
from pandas.testing import assert_frame_equal

from app.nitmre_nlp_utils import preprocess as pre


def generate_random_timestamp(
start_dt: datetime = datetime(2000, 1, 1),
end_dt: datetime = datetime(2024, 12, 31),
) -> datetime:
sts = start_dt.timestamp()
ets = end_dt.timestamp()

random_dt = random.uniform(sts, ets)
return datetime.fromtimestamp(random_dt)


def test_threading_successful():
df = pd.DataFrame(
{
'id': [0, 1, 2, 3, 4, 5, 7, 9, 10, 11],
'root_id': ['', np.nan, 0, 0, 1, '', 6, 8, 8, 8],
'create_at': [
datetime(2024, 1, 1),
datetime(2024, 2, 1),
datetime(2024, 4, 1),
datetime(2024, 3, 1),
datetime(2024, 5, 1),
datetime(2024, 6, 1),
datetime(2024, 8, 1),
datetime(2024, 10, 1),
datetime(2024, 9, 1),
datetime(2024, 11, 1),
],
'message': [
'zero',
'one',
'two',
'three',
'four',
'five',
'seven',
'nine',
'ten',
'eleven',
],
}
)

expected_root_ids = (0, 1, 5, 7, 10)
expected = pd.DataFrame(
{
'id': expected_root_ids, # empty string and nans collapse to root id
'root_id': ['', np.nan, '', '', ''],
'create_at': df[df['id'].isin(expected_root_ids)]['create_at'], # type: ignore
'message': [
'zero\nthree\ntwo', # threads are sorted by create_at
'one\nfour', # threaded messages are delimited by newlines
'five', # keep messages that don't have a thread
'seven', # Threads with a missing root message aren't lost...
'ten\nnine\neleven', # ...even with more than one message in the thread.
],
}
)
result = pre.convert_conversation_threads(df, 'message')

assert_frame_equal(result, expected)


def test_threading_incorrect_columns():
message_col_name = 'messages'
num_rows = 5

df = pd.DataFrame(
{
'idx': np.arange(num_rows),
'root-id': np.arange(num_rows - 1, -1, -1),
'created_at': list(
generate_random_timestamp() for _ in range(num_rows)
),
message_col_name: list(str(i) for i in range(num_rows)),
}
)

with pytest.raises(KeyError):
pre.convert_conversation_threads(df, message_col_name)