Skip to content

Commit 743e698

Browse files
authored
Merge pull request #267 from nyu-mlab/launch
Prolific Launch
2 parents b6f2d5e + 28ab2ab commit 743e698

File tree

6 files changed

+591
-438
lines changed

6 files changed

+591
-438
lines changed

.github/workflows/create_release.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,4 @@ jobs:
113113
generate_release_notes: true
114114
prerelease: false
115115
files: |
116-
dist/*.exe
117116
dist/*.whl

src/libinspector/common.py

Lines changed: 156 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import streamlit as st
99
import logging
1010
import re
11+
import matplotlib.pyplot as plt
1112
import libinspector.global_state
1213
from libinspector.privacy import is_ad_tracked
1314

15+
1416
config_file_name = 'config.json'
1517
config_lock = threading.Lock()
1618
config_dict = {}
@@ -28,7 +30,8 @@
2830

2931
def remove_warning():
3032
"""
31-
Remove the warning acceptance state, forcing the user to see the warning again.
33+
Remove the warning screen by setting the suppress_warning flag in the config.
34+
This lasts for the duration of this IoT Inspector session.
3235
"""
3336
config_set("suppress_warning", True)
3437

@@ -40,17 +43,6 @@ def reset_prolific_id():
4043
config_set("prolific_id", "")
4144

4245

43-
def set_prolific_id(prolific_id: str):
44-
"""
45-
Store the provided Prolific ID in the configuration.
46-
47-
Args:
48-
prolific_id (str): The Prolific ID to store.
49-
"""
50-
if is_prolific_id_valid(prolific_id):
51-
config_set("prolific_id", prolific_id)
52-
53-
5446
def show_warning():
5547
"""
5648
Displays a warning message to the user about network monitoring and ARP spoofing.
@@ -62,10 +54,11 @@ def show_warning():
6254
"""
6355
current_id = config_get("prolific_id", "")
6456
st.subheader("1. Prolific ID Confirmation")
65-
st.info(f"Your currently stored ID is: `{current_id}`")
66-
st.button("Change Prolific ID",
67-
on_click=reset_prolific_id,
68-
help="Clicking this will clear your stored ID and return you to the ID entry form.")
57+
if current_id != "":
58+
st.info(f"Your currently stored ID is: `{current_id}`")
59+
st.button("Change Prolific ID",
60+
on_click=reset_prolific_id,
61+
help="Clicking this will clear your stored ID and return you to the ID entry form.")
6962

7063
# --- GATE 1: PROLIFIC ID CHECK (Must be valid to proceed to confirmation) ---
7164
if is_prolific_id_valid(current_id):
@@ -96,10 +89,19 @@ def show_warning():
9689
value="",
9790
key="prolific_id_input"
9891
).strip()
99-
st.form_submit_button("Submit ID",
100-
on_click=set_prolific_id,
101-
args=(input_id,),
102-
help="Submit your Prolific ID to proceed.")
92+
submitted = st.form_submit_button("Submit ID", help="Submit your Prolific ID to proceed.")
93+
94+
if submitted:
95+
if is_prolific_id_valid(input_id):
96+
# 1. Set the valid ID
97+
config_set("prolific_id", input_id)
98+
st.success("Prolific ID accepted. Please review the details below.")
99+
100+
# 2. Rerun the script. In the next run, is_prolific_id_valid(current_id)
101+
# will be True, and the user jumps to the warning acceptance (Gate 2).
102+
st.rerun()
103+
else:
104+
st.error("Invalid Prolific ID. Must be 1-50 alphanumeric characters.")
103105

104106
return True # BLOCK: ID check still needs resolution.
105107

@@ -130,54 +132,157 @@ def is_prolific_id_valid(prolific_id: str) -> bool:
130132
return True
131133

132134

133-
def bar_graph_data_frame(mac_address: str, now: int):
135+
@st.cache_data(ttl=1, show_spinner=False)
136+
def bar_graph_data_frame(now: int):
137+
"""
138+
Retrieves and processes network flow data for ALL non-gateway devices
139+
for the last 60 seconds, performing zero-filling directly in the SQL query.
140+
141+
Args:
142+
now (int): The current epoch timestamp.
143+
Returns:
144+
(pd.DataFrame, pd.DataFrame): DataFrames for upload and download traffic,
145+
containing 'mac_address', 'seconds_ago', and 'Bits'.
146+
"""
134147
sixty_seconds_ago = now - 60
135-
db_conn, rwlock = libinspector.global_state.db_conn_and_lock
148+
# Parameters array: [?1/now, ?2/sixty_seconds_ago]
149+
params = [now, sixty_seconds_ago]
136150

151+
# --- SQL for Upload Chart: Includes mac_address and Zero-Filling for all devices ---
137152
sql_upload_chart = """
138-
SELECT timestamp, SUM (byte_count) * 8 AS Bits
139-
FROM network_flows
140-
WHERE src_mac_address = ?
141-
AND timestamp >= ?
142-
GROUP BY timestamp
143-
ORDER BY timestamp DESC
144-
"""
145-
153+
WITH RECURSIVE seq(s) AS (
154+
-- Generate 60 time slots (0 to 59 seconds ago)
155+
SELECT 0
156+
UNION ALL
157+
SELECT s + 1 FROM seq WHERE s < 59
158+
),
159+
device_macs AS (
160+
-- Select all non-gateway MAC addresses to include in the template
161+
SELECT mac_address FROM devices WHERE is_gateway = 0
162+
),
163+
zero_fill_template AS (
164+
-- Create a template of (MAC, seconds_ago) for every device (60 * N rows)
165+
SELECT T1.mac_address, T2.s AS seconds_ago
166+
FROM device_macs T1, seq T2
167+
),
168+
agg AS (
169+
-- Aggregate ACTUAL traffic, grouped by source MAC and seconds_ago
170+
SELECT src_mac_address AS mac_address,
171+
CAST((?1 - timestamp) AS INTEGER) AS seconds_ago,
172+
SUM(byte_count) * 8 AS Bits
173+
FROM network_flows
174+
WHERE timestamp >= ?2
175+
GROUP BY src_mac_address, seconds_ago
176+
)
177+
-- Join the template (ensuring zero fill) with the aggregated data
178+
SELECT z.mac_address,
179+
z.seconds_ago,
180+
COALESCE(a.Bits, 0) AS Bits
181+
FROM zero_fill_template z
182+
LEFT JOIN agg a
183+
ON z.mac_address = a.mac_address AND z.seconds_ago = a.seconds_ago
184+
ORDER BY z.mac_address, z.seconds_ago DESC
185+
"""
186+
187+
# --- SQL for Download Chart: Includes mac_address and Zero-Filling for all devices ---
146188
sql_download_chart = """
147-
SELECT timestamp, SUM (byte_count) * 8 AS Bits
148-
FROM network_flows
149-
WHERE dest_mac_address = ?
150-
AND timestamp >= ?
151-
GROUP BY timestamp
152-
ORDER BY timestamp DESC
153-
"""
189+
WITH RECURSIVE seq(s) AS (
190+
SELECT 0
191+
UNION ALL
192+
SELECT s + 1 FROM seq WHERE s < 59
193+
),
194+
device_macs AS (
195+
SELECT mac_address FROM devices WHERE is_gateway = 0
196+
),
197+
zero_fill_template AS (
198+
SELECT T1.mac_address, T2.s AS seconds_ago
199+
FROM device_macs T1, seq T2
200+
),
201+
agg AS (
202+
-- Aggregate ACTUAL traffic, grouped by destination MAC and seconds_ago
203+
SELECT dest_mac_address AS mac_address,
204+
CAST((?1 - timestamp) AS INTEGER) AS seconds_ago,
205+
SUM(byte_count) * 8 AS Bits
206+
FROM network_flows
207+
WHERE timestamp >= ?2
208+
GROUP BY dest_mac_address, seconds_ago
209+
)
210+
SELECT z.mac_address,
211+
z.seconds_ago,
212+
COALESCE(a.Bits, 0) AS Bits
213+
FROM zero_fill_template z
214+
LEFT JOIN agg a
215+
ON z.mac_address = a.mac_address AND z.seconds_ago = a.seconds_ago
216+
ORDER BY z.mac_address, z.seconds_ago DESC
217+
"""
154218

219+
db_conn, rwlock = libinspector.global_state.db_conn_and_lock
155220
with rwlock:
156-
df_upload_bar_graph = pd.read_sql_query(sql_upload_chart, db_conn,
157-
params=(mac_address, sixty_seconds_ago))
158-
df_download_bar_graph = pd.read_sql_query(sql_download_chart, db_conn,
159-
params=(mac_address, sixty_seconds_ago))
221+
df_upload_bar_graph = pd.read_sql_query(sql_upload_chart, db_conn, params=params)
222+
df_download_bar_graph = pd.read_sql_query(sql_download_chart, db_conn, params=params)
223+
160224
return df_upload_bar_graph, df_download_bar_graph
161225

162226

163-
def plot_traffic_volume(df: pd.DataFrame, now: int, chart_title: str):
227+
def plot_traffic_volume(df: pd.DataFrame, chart_title: str, full_width: bool = False):
164228
"""
165229
Plots the traffic volume over time. The bar goes from right to left,
166230
like Task Manager in Windows.
167231
168232
Args:
169233
df (pd.DataFrame): DataFrame containing 'Time' and 'Bits' columns.
170-
now: The current epoch time from which the sql query was executed
171234
chart_title: The title to display above the chart.
235+
full_width: Whether to use full width for the chart (True) or a smaller size (False).
172236
"""
173237
if df.empty:
174238
st.caption("No traffic data to display in chart.")
239+
return
240+
241+
# 1. Prepare data and sort: Sort by seconds_ago descending.
242+
# This places the older data on the left and the most recent data on the right.
243+
df_plot = df.drop(columns=['mac_address'])
244+
df_plot = df_plot.sort_values(by='seconds_ago', ascending=False).reset_index(drop=True)
245+
246+
# 2. Create the Matplotlib figure
247+
if full_width:
248+
# Larger figure size for full-page view
249+
fig, ax = plt.subplots(figsize=(16, 5))
175250
else:
176-
st.markdown(f"#### {chart_title}")
177-
df['seconds_ago'] = now - df['timestamp'].astype(int)
178-
df_reindexed = df.set_index('seconds_ago').reindex(range(0, 60), fill_value=0).reset_index()
179-
df_reindexed = df_reindexed.sort_values(by='seconds_ago', ascending=False)
180-
st.bar_chart(df_reindexed.set_index('seconds_ago')['Bits'], width='content')
251+
# Smaller figure size for two-column view (like in the card)
252+
fig, ax = plt.subplots(figsize=(10, 4))
253+
254+
# 3. Create the bar chart
255+
# Use the index as the x-position, and 'Bits' as the height.
256+
bars = ax.bar(df_plot.index, df_plot['Bits'], color='#1f77b4') # Streamlit blue
257+
258+
# Optional: Add clear labels to the bars for visual clarity
259+
for bar in bars:
260+
# Only label the tallest bars to prevent clutter
261+
if bar.get_height() > df_plot['Bits'].max() * 0.1:
262+
ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
263+
f'{bar.get_height():.0f}',
264+
ha='center', va='bottom', fontsize=8)
265+
266+
# 4. Set labels and title
267+
ax.set_title(chart_title, fontsize=14)
268+
ax.set_ylabel('Traffic Volume (Bits)', fontsize=10)
269+
270+
# Set X-axis ticks to show the actual 'seconds_ago' values
271+
# We select a few points to label to keep the axis clean.
272+
tick_positions = df_plot.index[::max(1, len(df_plot) // 8)]
273+
# Ensure tick labels are formatted as 'Xs'
274+
tick_labels = [f"{s}s" for s in df_plot['seconds_ago'].iloc[tick_positions]]
275+
ax.set_xticks(tick_positions)
276+
ax.set_xticklabels(tick_labels, rotation=45, ha="right", fontsize=8)
277+
ax.set_xlabel('Time (Seconds Ago)', fontsize=10)
278+
279+
# Clean up the plot
280+
plt.tight_layout()
281+
282+
# 5. Display the Matplotlib figure using st.pyplot
283+
st.pyplot(fig, clear_figure=True, width="content")
284+
# Important: close the figure to free up memory
285+
plt.close(fig)
181286

182287

183288
def get_device_metadata(mac_address: str) -> dict:
@@ -204,7 +309,7 @@ def get_device_metadata(mac_address: str) -> dict:
204309
return dict()
205310

206311

207-
def get_remote_hostnames(mac_address: str):
312+
def get_remote_hostnames(mac_address: str) -> str:
208313
"""
209314
Retrieve all distinct remote hostnames associated with a device's MAC address from network flows.
210315
@@ -247,7 +352,7 @@ def get_human_readable_time(timestamp=None):
247352
"""
248353
if timestamp is None:
249354
timestamp = time.time()
250-
return datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
355+
return datetime.datetime.fromtimestamp(timestamp).strftime("%b %d,%Y %I:%M:%S%p")
251356

252357

253358
def initialize_config_dict():
@@ -323,7 +428,7 @@ def config_set(key: str, value: typing.Any):
323428

324429
# Write the updated config_dict to the file
325430
with open(config_file_name, 'w') as f:
326-
json.dump(config_dict, f, indent=2, sort_keys=True)
431+
json.dump(config_dict, f, indent=4, sort_keys=True)
327432

328433

329434
def get_device_custom_name(mac_address: str) -> str:

0 commit comments

Comments
 (0)