Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions oasis/social_platform/recsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,25 @@
date_score = []


def _extract_trace_post_id(trace: Dict[str, Any]) -> Any | None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use extract_trace_post_id instead of _extract_trace_post_id may better align with other functions, what do you think?

"""Safely extract a post_id from a trace row's serialized info payload."""
info = trace.get('info')
if info is None:
return None

if isinstance(info, dict):
return info.get('post_id')

try:
parsed_info = literal_eval(info)
except (ValueError, SyntaxError):
return None

if isinstance(parsed_info, dict):
return parsed_info.get('post_id')
return None


def get_twhin_tokenizer():
global twhin_tokenizer
if twhin_tokenizer is None:
Expand Down Expand Up @@ -367,10 +386,13 @@ def get_like_post_id(user_id, action, trace_table):
list: List of post IDs.
"""
# Get post IDs from trace table for the given user and action
trace_post_ids = [
literal_eval(trace['info'])["post_id"] for trace in trace_table
if (trace['user_id'] == user_id and trace['action'] == action)
]
trace_post_ids = []
for trace in trace_table:
if trace['user_id'] != user_id or trace['action'] != action:
continue
post_id = _extract_trace_post_id(trace)
if post_id is not None:
trace_post_ids.append(post_id)
"""Only take the last 5 liked posts, if not enough, pad with the most
recently liked post. Only take IDs, not content, because calculating
embeddings for all posts again is very time-consuming, especially when the
Expand Down Expand Up @@ -666,10 +688,13 @@ def get_trace_contents(user_id, action, post_table, trace_table):
list: List of post contents.
"""
# Get post IDs from trace table for the given user and action
trace_post_ids = [
trace['post_id'] for trace in trace_table
if (trace['user_id'] == user_id and trace['action'] == action)
]
trace_post_ids = []
for trace in trace_table:
if trace['user_id'] != user_id or trace['action'] != action:
continue
post_id = _extract_trace_post_id(trace)
if post_id is not None:
trace_post_ids.append(post_id)
# Fetch post contents from post table where post IDs match those in the
# trace
trace_contents = [
Expand Down Expand Up @@ -784,8 +809,10 @@ def rec_sys_personalized_with_trace(
swap_free_ids = [
post_id for post_id in post_ids
if post_id not in rec_post_ids and post_id not in [
trace['post_id']
for trace in trace_table if trace['user_id']
trace_post_id for trace_post_id in (
_extract_trace_post_id(trace)
for trace in trace_table
) if trace_post_id is not None
Comment on lines +812 to +815
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic has changed compared to the original implementation. The original collected trace['post_id'] for every trace with a truthy trace['user_id'], whereas this version collects trace['info']['post_id']. Could you please check this and explain the change in more detail?

]
]
rec_post_ids = swap_random_posts(rec_post_ids, swap_free_ids,
Expand Down
46 changes: 46 additions & 0 deletions test/infra/recsys/test_recsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import random

from oasis.social_platform.recsys import (rec_sys_personalized,
rec_sys_personalized_twh,
rec_sys_personalized_with_trace,
rec_sys_random, rec_sys_reddit,
reset_globals)

Expand Down Expand Up @@ -244,6 +247,49 @@ def test_rec_sys_personalized_sample_posts():
assert result[i] == ["1", "2"]


def test_rec_sys_personalized_with_trace_ignores_trace_rows_without_post_id():
    """Trace rows whose info payload lacks a post_id must be tolerated."""
    users = [{"user_id": 1, "bio": "I like cats"}]
    posts = [
        {"post_id": pid, "user_id": author, "content": text}
        for pid, author, text in (
            (1, 2, "Cats are great"),
            (2, 3, "Dogs are great"),
            (3, 4, "Birds are great"),
        )
    ]
    # The serialized info payload deliberately has no "post_id" key.
    traces = [{
        "user_id": 1,
        "created_at": 1,
        "action": "refresh",
        "info": '{"content": "timeline refreshed"}',
    }]
    rec_matrix = [[], []]

    # Seed the RNG so the run is reproducible.
    random.seed(0)
    recommended = rec_sys_personalized_with_trace(
        users, posts, traces, rec_matrix, max_rec_post_len=2, swap_rate=0.5)

    # One user, two recommendations, all drawn from the known post ids.
    assert len(recommended) == 1
    assert len(recommended[0]) == 2
    assert set(recommended[0]) <= {1, 2, 3}


def test_rec_sys_personalized_twhin_sample_posts():
# Test the scenario when the number of tweets is greater than the maximum
# recommendation length
Expand Down
Loading