Skip to content

Commit 96ebb43

Browse files
committed
feat: expose Querier class
1 parent 7e198c0 commit 96ebb43

File tree

7 files changed

+55
-35
lines changed

7 files changed

+55
-35
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,6 @@ docs/_build/
7272
.python-version
7373
# stack graph sqlite dbs
7474
*.db
75+
*.sqlite
7576

7677
.mypy_cache/

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ classifiers = [
1010
"Programming Language :: Python :: Implementation :: CPython",
1111
"Programming Language :: Python :: Implementation :: PyPy",
1212
]
13-
version = "0.0.4"
13+
version = "0.0.5"
1414
[tool.maturin]
1515
features = ["pyo3/extension-module"]

src/classes.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ use std::fmt::Display;
22

33
use pyo3::prelude::*;
44

5+
use stack_graphs::storage::SQLiteReader;
56
use tree_sitter_stack_graphs::cli::util::{SourcePosition, SourceSpan};
67

8+
use crate::stack_graphs_wrapper::query_definition;
9+
710
#[pyclass]
811
#[derive(Clone)]
912
pub enum Language {
@@ -24,6 +27,37 @@ pub struct Position {
2427
column: usize,
2528
}
2629

30+
#[pyclass]
31+
pub struct Querier {
32+
db_reader: SQLiteReader,
33+
}
34+
35+
#[pymethods]
36+
impl Querier {
37+
#[new]
38+
pub fn new(db_path: String) -> Self {
39+
println!("Opening database: {}", db_path);
40+
Querier {
41+
db_reader: SQLiteReader::open(db_path).unwrap(),
42+
}
43+
}
44+
45+
pub fn definitions(&mut self, reference: Position) -> PyResult<Vec<Position>> {
46+
let result = query_definition(reference.into(), &mut self.db_reader)?;
47+
48+
let positions: Vec<Position> = result
49+
.into_iter()
50+
.map(|r| r.targets)
51+
.flatten()
52+
.map(|t| t.into())
53+
.collect();
54+
55+
Ok(positions)
56+
}
57+
}
58+
59+
// TODO(@nohehf): Indexer class
60+
2761
#[pymethods]
2862
impl Position {
2963
#[new]

src/lib.rs

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use pyo3::prelude::*;
33
mod classes;
44
mod stack_graphs_wrapper;
55

6-
use classes::{Language, Position};
6+
use classes::{Language, Position, Querier};
77

88
/// Formats the sum of two numbers as string.
99
#[pyfunction]
@@ -14,8 +14,8 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
1414
#[pyfunction]
1515
fn index(paths: Vec<String>, db_path: String, language: Language) -> PyResult<()> {
1616
// TODO(@nohehf): Add a verbose mode to toggle the logs
17-
println!("Indexing paths: {:?}", paths);
18-
println!("Database path: {:?}", db_path);
17+
// println!("Indexing paths: {:?}", paths);
18+
// println!("Database path: {:?}", db_path);
1919

2020
let paths: Vec<std::path::PathBuf> =
2121
paths.iter().map(|p| std::path::PathBuf::from(p)).collect();
@@ -27,32 +27,13 @@ fn index(paths: Vec<String>, db_path: String, language: Language) -> PyResult<()
2727
)?)
2828
}
2929

30-
/// Indexes the given paths into stack graphs, and stores the results in the given database.
31-
#[pyfunction]
32-
fn query_definition(reference: Position, db_path: String) -> PyResult<Vec<Position>> {
33-
println!("Querying reference: {:?}", reference.to_string());
34-
println!("Database path: {:?}", db_path);
35-
36-
let result = stack_graphs_wrapper::query_definition(reference.into(), &db_path)?;
37-
38-
// TODO(@nohehf): Check if we can flatten the results, see the QueryResult struct, we might be loosing some information
39-
let positions: Vec<Position> = result
40-
.into_iter()
41-
.map(|r| r.targets)
42-
.flatten()
43-
.map(|t| t.into())
44-
.collect();
45-
46-
Ok(positions)
47-
}
48-
4930
/// A Python module implemented in Rust.
5031
#[pymodule]
5132
fn stack_graphs_python(_py: Python, m: &PyModule) -> PyResult<()> {
5233
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
5334
m.add_function(wrap_pyfunction!(index, m)?)?;
54-
m.add_function(wrap_pyfunction!(query_definition, m)?)?;
5535
m.add_class::<Position>()?;
5636
m.add_class::<Language>()?;
37+
m.add_class::<Querier>()?;
5738
Ok(())
5839
}

src/stack_graphs_wrapper/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,11 @@ pub fn index(
8383

8484
pub fn query_definition(
8585
reference: SourcePosition,
86-
db_path: &str,
86+
db_reader: &mut SQLiteReader,
8787
) -> Result<Vec<QueryResult>, StackGraphsError> {
88-
let mut db_read = SQLiteReader::open(&db_path).expect("failed to open database");
89-
9088
let reporter = ConsoleReporter::none();
9189

92-
let mut querier = Querier::new(&mut db_read, &reporter);
90+
let mut querier = Querier::new(db_reader, &reporter);
9391

9492
// print_source_position(&reference);
9593

stack_graphs_python.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ class Position:
1313

1414
def __init__(self, path: str, line: int, column: int) -> None: ...
1515

16+
class Querier:
17+
def __init__(self, db_path: str) -> None: ...
18+
def definitions(self, reference: Position) -> list[Position]: ...
19+
1620
def index(paths: list[str], db_path: str, language: Language) -> None: ...
17-
def query_definition(reference: Position, db_path: str) -> list[Position]: ...

tests/test.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1+
# TODO(@nohehf): Make this a propper pytest test & run in CI
12
import os
2-
from stack_graphs_python import index, query_definition, Position, Language
3+
from stack_graphs_python import index, Querier, Position, Language
34

45
# index ./js_sample directory
56

67
# convert ./js_sample directory to absolute path
78
dir = os.path.abspath("./tests/js_sample")
8-
db = os.path.abspath("./js_sample.db")
9+
db_path = os.path.abspath("./db.sqlite")
910

1011
print("Indexing directory: ", dir)
11-
print("Database path: ", db)
12+
print("Database path: ", db_path)
1213

13-
index([dir], db, language=Language.Python)
14+
index([dir], db_path, language=Language.JavaScript)
1415

15-
source_reference: Position = Position(path=dir + "/index.js", line=2, column=12)
16+
source_reference = Position(path=dir + "/index.js", line=2, column=12)
1617

1718
print("Querying definition for: ", source_reference.path)
1819

19-
results = query_definition(source_reference, db)
20+
querier = Querier(db_path)
21+
22+
results = querier.definitions(source_reference)
2023

2124
print("Results: ", results)
2225

0 commit comments

Comments
 (0)