|
| 1 | +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; |
| 2 | +use postgres::Client; |
| 3 | +use postgres_openssl::MakeTlsConnector; |
| 4 | +use rand::{distributions::Alphanumeric, thread_rng, Rng}; |
| 5 | +use std::{borrow::Cow, collections::BTreeMap, fmt}; |
| 6 | + |
| 7 | +use synapse_compress_state::StateGroupEntry; |
| 8 | + |
| 9 | +pub mod map_builder; |
| 10 | + |
| 11 | +pub static DB_URL: &str = "postgresql://synapse_user:synapse_pass@localhost/synapse"; |
| 12 | + |
| 13 | +/// Adds the contents of a state group map to the testing database |
| 14 | +pub fn add_contents_to_database(room_id: &str, state_group_map: &BTreeMap<i64, StateGroupEntry>) { |
| 15 | + // connect to the database |
| 16 | + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); |
| 17 | + builder.set_verify(SslVerifyMode::NONE); |
| 18 | + let connector = MakeTlsConnector::new(builder.build()); |
| 19 | + |
| 20 | + let mut client = Client::connect(DB_URL, connector).unwrap(); |
| 21 | + |
| 22 | + // build up the query |
| 23 | + let mut sql = "".to_string(); |
| 24 | + |
| 25 | + for (sg, entry) in state_group_map { |
| 26 | + // create the entry for state_groups |
| 27 | + sql.push_str(&format!( |
| 28 | + "INSERT INTO state_groups (id, room_id, event_id) VALUES ({},{},{});\n", |
| 29 | + sg, |
| 30 | + PGEscape(room_id), |
| 31 | + PGEscape("left_blank") |
| 32 | + )); |
| 33 | + |
| 34 | + // create the entry in state_group_edges IF exists |
| 35 | + if let Some(prev_sg) = entry.prev_state_group { |
| 36 | + sql.push_str(&format!( |
| 37 | + "INSERT INTO state_group_edges (state_group, prev_state_group) VALUES ({}, {});\n", |
| 38 | + sg, prev_sg |
| 39 | + )); |
| 40 | + } |
| 41 | + |
| 42 | + // write entry for each row in delta |
| 43 | + if !entry.state_map.is_empty() { |
| 44 | + sql.push_str("INSERT INTO state_groups_state (state_group, room_id, type, state_key, event_id) VALUES"); |
| 45 | + |
| 46 | + let mut first = true; |
| 47 | + for ((t, s), e) in entry.state_map.iter() { |
| 48 | + if first { |
| 49 | + sql.push_str(" "); |
| 50 | + first = false; |
| 51 | + } else { |
| 52 | + sql.push_str(" ,"); |
| 53 | + } |
| 54 | + sql.push_str(&format!( |
| 55 | + "({}, {}, {}, {}, {})", |
| 56 | + sg, |
| 57 | + PGEscape(room_id), |
| 58 | + PGEscape(t), |
| 59 | + PGEscape(s), |
| 60 | + PGEscape(e) |
| 61 | + )); |
| 62 | + } |
| 63 | + sql.push_str(";\n"); |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + client.batch_execute(&sql).unwrap(); |
| 68 | +} |
| 69 | + |
| 70 | +// Clears the contents of the testing database |
| 71 | +pub fn empty_database() { |
| 72 | + // connect to the database |
| 73 | + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); |
| 74 | + builder.set_verify(SslVerifyMode::NONE); |
| 75 | + let connector = MakeTlsConnector::new(builder.build()); |
| 76 | + |
| 77 | + let mut client = Client::connect(DB_URL, connector).unwrap(); |
| 78 | + |
| 79 | + // delete all the contents from all three tables |
| 80 | + let sql = r" |
| 81 | + DELETE FROM state_groups; |
| 82 | + DELETE FROM state_group_edges; |
| 83 | + DELETE FROM state_groups_state; |
| 84 | + "; |
| 85 | + |
| 86 | + client.batch_execute(sql).unwrap(); |
| 87 | +} |
| 88 | + |
| 89 | +/// Safely escape the strings in sql queries |
| 90 | +struct PGEscape<'a>(pub &'a str); |
| 91 | + |
| 92 | +impl<'a> fmt::Display for PGEscape<'a> { |
| 93 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 94 | + let mut delim = Cow::from("$$"); |
| 95 | + while self.0.contains(&delim as &str) { |
| 96 | + let s: String = thread_rng() |
| 97 | + .sample_iter(&Alphanumeric) |
| 98 | + .take(10) |
| 99 | + .map(char::from) |
| 100 | + .collect(); |
| 101 | + |
| 102 | + delim = format!("${}$", s).into(); |
| 103 | + } |
| 104 | + |
| 105 | + write!(f, "{}{}{}", delim, self.0, delim) |
| 106 | + } |
| 107 | +} |
0 commit comments