Skip to content

Commit 46bee46

Browse files
benjipelletierfacebook-github-bot
authored andcommitted
Add tracing sqlite db file support so we can connect from python (#807)
Summary: Pull Request resolved: #807 POC of Rust-create sqlite db accessable from native python sqlite connection. This allows us to write python tests that test Monarch and assert user (or BE) events using sql queries. * Uses reloadable layer to inject SqliteLayer into tracing registry on demand. * Exposes `with_tracing_db_file` to python to create and get the DB file name so we can connect to Reviewed By: eliothedeman, pablorfb-meta Differential Revision: D79761474 fbshipit-source-id: 4cabfffe3f591efeeaf943dee18dffe3bec4cb0d
1 parent 855cbbb commit 46bee46

File tree

4 files changed

+310
-25
lines changed

4 files changed

+310
-25
lines changed

hyperactor_telemetry/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ use tracing_subscriber::fmt::format::Writer;
6464
use tracing_subscriber::registry::LookupSpan;
6565

6666
use crate::recorder::Recorder;
67+
use crate::sqlite::get_reloadable_sqlite_layer;
6768

6869
pub trait TelemetryClock {
6970
fn now(&self) -> tokio::time::Instant;
@@ -563,6 +564,8 @@ pub fn initialize_logging_with_log_prefix(
563564
.with_target("opentelemetry", LevelFilter::OFF), // otel has some log span under debug that we don't care about
564565
);
565566

567+
let sqlite_layer = get_reloadable_sqlite_layer().unwrap();
568+
566569
use tracing_subscriber::Registry;
567570
use tracing_subscriber::layer::SubscriberExt;
568571
use tracing_subscriber::util::SubscriberInitExt;
@@ -574,6 +577,7 @@ pub fn initialize_logging_with_log_prefix(
574577
std::env::var(env_var).unwrap_or_default() != "1"
575578
}
576579
if let Err(err) = Registry::default()
580+
.with(sqlite_layer)
577581
.with(if is_layer_enabled(DISABLE_OTEL_TRACING) {
578582
Some(otel::tracing_layer())
579583
} else {

hyperactor_telemetry/src/sqlite.rs

Lines changed: 187 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88

99
use std::collections::HashMap;
10+
use std::fs;
11+
use std::path::PathBuf;
1012
use std::sync::Arc;
1113
use std::sync::Mutex;
1214

@@ -20,11 +22,19 @@ use serde_json::Value as JValue;
2022
use serde_rusqlite::*;
2123
use tracing::Event;
2224
use tracing::Subscriber;
23-
use tracing::level_filters::LevelFilter;
2425
use tracing_subscriber::Layer;
25-
use tracing_subscriber::filter::Targets;
26+
use tracing_subscriber::Registry;
2627
use tracing_subscriber::prelude::*;
28+
use tracing_subscriber::reload;
2729

30+
pub type SqliteReloadHandle = reload::Handle<Option<SqliteLayer>, Registry>;
31+
32+
lazy_static! {
33+
// Reload handle allows us to include a no-op layer during init, but load
34+
// the layer dynamically during tests.
35+
static ref RELOAD_HANDLE: Mutex<Option<SqliteReloadHandle>> =
36+
Mutex::new(None);
37+
}
2838
pub trait TableDef {
2939
fn name(&self) -> &'static str;
3040
fn columns(&self) -> &'static [&'static str];
@@ -224,7 +234,15 @@ macro_rules! insert_event {
224234
impl SqliteLayer {
225235
pub fn new() -> Result<Self> {
226236
let conn = Connection::open_in_memory()?;
237+
Self::setup_connection(conn)
238+
}
239+
240+
pub fn new_with_file(db_path: &str) -> Result<Self> {
241+
let conn = Connection::open(db_path)?;
242+
Self::setup_connection(conn)
243+
}
227244

245+
fn setup_connection(conn: Connection) -> Result<Self> {
228246
for table in ALL_TABLES.iter() {
229247
conn.execute(&table.create_table_stmt, [])?;
230248
}
@@ -326,21 +344,89 @@ fn print_table(conn: &Connection, table_name: TableName) -> Result<()> {
326344
Ok(())
327345
}
328346

329-
pub fn with_tracing_db() -> Arc<Mutex<Connection>> {
330-
let layer = SqliteLayer::new().unwrap();
331-
let conn = layer.connection();
332-
333-
let layer = layer.with_filter(
334-
Targets::new()
335-
.with_default(LevelFilter::TRACE)
336-
.with_targets(vec![
337-
("tokio", LevelFilter::OFF),
338-
("opentelemetry", LevelFilter::OFF),
339-
("runtime", LevelFilter::OFF),
340-
]),
341-
);
342-
tracing_subscriber::registry().with(layer).init();
343-
conn
347+
fn init_tracing_subscriber(layer: SqliteLayer) {
348+
let handle = RELOAD_HANDLE.lock().unwrap();
349+
if let Some(reload_handle) = handle.as_ref() {
350+
let _ = reload_handle.reload(layer);
351+
} else {
352+
tracing_subscriber::registry().with(layer).init();
353+
}
354+
}
355+
356+
// === API ===
357+
358+
// Creates a new reload handler and no-op layer for initialization
359+
pub fn get_reloadable_sqlite_layer() -> Result<reload::Layer<Option<SqliteLayer>, Registry>> {
360+
let (layer, reload_handle) = reload::Layer::new(None);
361+
let mut handle = RELOAD_HANDLE.lock().unwrap();
362+
*handle = Some(reload_handle);
363+
Ok(layer)
364+
}
365+
366+
/// RAII guard for SQLite tracing database
367+
pub struct SqliteTracing {
368+
db_path: Option<PathBuf>,
369+
connection: Arc<Mutex<Connection>>,
370+
}
371+
372+
impl SqliteTracing {
373+
/// Create a new SqliteTracing with a temporary file
374+
pub fn new() -> Result<Self> {
375+
let temp_dir = std::env::temp_dir();
376+
let file_name = format!("hyperactor_trace_{}.db", std::process::id());
377+
let db_path = temp_dir.join(file_name);
378+
379+
let db_path_str = db_path.to_string_lossy();
380+
let layer = SqliteLayer::new_with_file(&db_path_str)?;
381+
let connection = layer.connection();
382+
383+
init_tracing_subscriber(layer);
384+
385+
Ok(Self {
386+
db_path: Some(db_path),
387+
connection,
388+
})
389+
}
390+
391+
/// Create a new SqliteTracing with in-memory database
392+
pub fn new_in_memory() -> Result<Self> {
393+
let layer = SqliteLayer::new()?;
394+
let connection = layer.connection();
395+
396+
init_tracing_subscriber(layer);
397+
398+
Ok(Self {
399+
db_path: None,
400+
connection,
401+
})
402+
}
403+
404+
/// Get the path to the temporary database file (None for in-memory)
405+
pub fn db_path(&self) -> Option<&PathBuf> {
406+
self.db_path.as_ref()
407+
}
408+
409+
/// Get a reference to the database connection
410+
pub fn connection(&self) -> Arc<Mutex<Connection>> {
411+
self.connection.clone()
412+
}
413+
}
414+
415+
impl Drop for SqliteTracing {
416+
fn drop(&mut self) {
417+
// Reset the layer to None
418+
let handle = RELOAD_HANDLE.lock().unwrap();
419+
if let Some(reload_handle) = handle.as_ref() {
420+
let _ = reload_handle.reload(None);
421+
}
422+
423+
// Delete the temporary file if it exists
424+
if let Some(db_path) = &self.db_path {
425+
if db_path.exists() {
426+
let _ = fs::remove_file(db_path);
427+
}
428+
}
429+
}
344430
}
345431

346432
#[cfg(test)]
@@ -350,8 +436,9 @@ mod tests {
350436
use super::*;
351437

352438
#[test]
353-
fn test_sqlite_layer() -> Result<()> {
354-
let conn = with_tracing_db();
439+
fn test_sqlite_tracing_with_file() -> Result<()> {
440+
let tracing = SqliteTracing::new()?;
441+
let conn = tracing.connection();
355442

356443
info!(target:"messages", test_field = "test_value", "Test msg");
357444
info!(target:"log_events", test_field = "test_value", "Test event");
@@ -362,6 +449,87 @@ mod tests {
362449
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
363450
print_table(&conn.lock().unwrap(), TableName::LogEvents)?;
364451
assert!(count > 0);
452+
453+
// Verify we have a file path
454+
assert!(tracing.db_path().is_some());
455+
let db_path = tracing.db_path().unwrap();
456+
assert!(db_path.exists());
457+
458+
Ok(())
459+
}
460+
461+
#[test]
462+
fn test_sqlite_tracing_in_memory() -> Result<()> {
463+
let tracing = SqliteTracing::new_in_memory()?;
464+
let conn = tracing.connection();
465+
466+
info!(target:"messages", test_field = "test_value", "Test event in memory");
467+
468+
let count: i64 =
469+
conn.lock()
470+
.unwrap()
471+
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
472+
print_table(&conn.lock().unwrap(), TableName::Messages)?;
473+
assert!(count > 0);
474+
475+
// Verify we don't have a file path for in-memory
476+
assert!(tracing.db_path().is_none());
477+
478+
Ok(())
479+
}
480+
481+
#[test]
482+
fn test_sqlite_tracing_cleanup() -> Result<()> {
483+
let db_path = {
484+
let tracing = SqliteTracing::new()?;
485+
let conn = tracing.connection();
486+
487+
info!(target:"log_events", test_field = "cleanup_test", "Test cleanup event");
488+
489+
let count: i64 =
490+
conn.lock()
491+
.unwrap()
492+
.query_row("SELECT COUNT(*) FROM log_events", [], |row| row.get(0))?;
493+
assert!(count > 0);
494+
495+
tracing.db_path().unwrap().clone()
496+
}; // tracing goes out of scope here, triggering Drop
497+
498+
// File should be cleaned up after Drop
499+
assert!(!db_path.exists());
500+
501+
Ok(())
502+
}
503+
504+
#[test]
505+
fn test_sqlite_tracing_different_targets() -> Result<()> {
506+
let tracing = SqliteTracing::new_in_memory()?;
507+
let conn = tracing.connection();
508+
509+
// Test different event targets
510+
info!(target:"messages", src = "actor1", dest = "actor2", payload = "test_message", "Message event");
511+
info!(target:"actor_lifecycle", actor_id = "123", actor = "TestActor", name = "test", "Lifecycle event");
512+
info!(target:"log_events", test_field = "general_event", "General event");
513+
514+
// Check that events went to the right tables
515+
let message_count: i64 =
516+
conn.lock()
517+
.unwrap()
518+
.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
519+
assert_eq!(message_count, 1);
520+
521+
let lifecycle_count: i64 =
522+
conn.lock()
523+
.unwrap()
524+
.query_row("SELECT COUNT(*) FROM actor_lifecycle", [], |row| row.get(0))?;
525+
assert_eq!(lifecycle_count, 1);
526+
527+
let events_count: i64 =
528+
conn.lock()
529+
.unwrap()
530+
.query_row("SELECT COUNT(*) FROM log_events", [], |row| row.get(0))?;
531+
assert_eq!(events_count, 1);
532+
365533
Ok(())
366534
}
367535
}

monarch_hyperactor/src/telemetry.rs

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::cell::Cell;
1313
use hyperactor::clock::ClockKind;
1414
use hyperactor::clock::RealClock;
1515
use hyperactor::clock::SimClock;
16+
use hyperactor_telemetry::sqlite::SqliteTracing;
1617
use hyperactor_telemetry::swap_telemetry_clock;
1718
use opentelemetry::global;
1819
use opentelemetry::metrics;
@@ -65,7 +66,6 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> {
6566
let file = record.getattr(py, "filename")?;
6667
let file: &str = file.extract(py)?;
6768
let level: i32 = record.getattr(py, "levelno")?.extract(py)?;
68-
6969
// Map level number to level name
7070
match level {
7171
40 | 50 => {
@@ -82,6 +82,7 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> {
8282
match traceback {
8383
Some(traceback) => {
8484
tracing::error!(
85+
target:"log_events",
8586
file = file,
8687
lineno = lineno,
8788
stacktrace = traceback,
@@ -93,10 +94,10 @@ pub fn forward_to_tracing(py: Python, record: PyObject) -> PyResult<()> {
9394
}
9495
}
9596
}
96-
30 => tracing::warn!(file = file, lineno = lineno, message),
97-
20 => tracing::info!(file = file, lineno = lineno, message),
98-
10 => tracing::debug!(file = file, lineno = lineno, message),
99-
_ => tracing::info!(file = file, lineno = lineno, message),
97+
30 => tracing::warn!(target:"log_events", file = file, lineno = lineno, message),
98+
20 => tracing::info!(target:"log_events", file = file, lineno = lineno, message),
99+
10 => tracing::debug!(target:"log_events", file = file, lineno = lineno, message),
100+
_ => tracing::info!(target:"log_events", file = file, lineno = lineno, message),
100101
}
101102
Ok(())
102103
}
@@ -215,6 +216,62 @@ impl PySpan {
215216
}
216217
}
217218

219+
#[pyclass(
220+
subclass,
221+
module = "monarch._rust_bindings.monarch_hyperactor.telemetry"
222+
)]
223+
struct PySqliteTracing {
224+
guard: Option<SqliteTracing>,
225+
}
226+
227+
#[pymethods]
228+
impl PySqliteTracing {
229+
#[new]
230+
#[pyo3(signature = (in_memory = false))]
231+
fn new(in_memory: bool) -> PyResult<Self> {
232+
let guard = if in_memory {
233+
SqliteTracing::new_in_memory()
234+
} else {
235+
SqliteTracing::new()
236+
};
237+
238+
match guard {
239+
Ok(guard) => Ok(Self { guard: Some(guard) }),
240+
Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
241+
"Failed to create SQLite tracing guard: {}",
242+
e
243+
))),
244+
}
245+
}
246+
247+
fn db_path(&self) -> PyResult<Option<String>> {
248+
match &self.guard {
249+
Some(guard) => Ok(guard.db_path().map(|p| p.to_string_lossy().to_string())),
250+
None => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
251+
"Guard has been closed",
252+
)),
253+
}
254+
}
255+
256+
fn __enter__(slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
257+
Ok(slf)
258+
}
259+
260+
fn __exit__(
261+
&mut self,
262+
_exc_type: Option<PyObject>,
263+
_exc_value: Option<PyObject>,
264+
_traceback: Option<PyObject>,
265+
) -> PyResult<bool> {
266+
self.guard = None;
267+
Ok(false) // Don't suppress exceptions
268+
}
269+
270+
fn close(&mut self) {
271+
self.guard = None;
272+
}
273+
}
274+
218275
use pyo3::Bound;
219276
use pyo3::types::PyModule;
220277

@@ -267,5 +324,6 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
267324
module.add_class::<PyCounter>()?;
268325
module.add_class::<PyHistogram>()?;
269326
module.add_class::<PyUpDownCounter>()?;
327+
module.add_class::<PySqliteTracing>()?;
270328
Ok(())
271329
}

0 commit comments

Comments
 (0)