Skip to content

Commit c720d47

Browse files
committed
Add ct_python macro.
1 parent 48f149a commit c720d47

File tree

6 files changed

+171
-52
lines changed

6 files changed

+171
-52
lines changed

macros/examples/example.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#![feature(proc_macro_hygiene)]
2+
3+
use inline_python_macros::ct_python;
4+
5+
static DIRECTIONS: [(f64, f64); 32] = ct_python! {
6+
from math import sin, cos, tau
7+
n = 32
8+
print("[")
9+
for i in range(n):
10+
x = cos(i / n * tau)
11+
y = sin(i / n * tau)
12+
print(f"({x}, {y}),")
13+
print("]")
14+
};
15+
16+
fn main() {
17+
dbg!(&DIRECTIONS);
18+
}

macros/src/embed_python.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub struct EmbedPython {
77
pub variables: BTreeMap<String, Ident>,
88
pub first_indent: Option<usize>,
99
pub loc: LineColumn,
10+
pub compile_time: bool,
1011
}
1112

1213
impl EmbedPython {
@@ -16,6 +17,7 @@ impl EmbedPython {
1617
variables: BTreeMap::new(),
1718
loc: LineColumn { line: 1, column: 0 },
1819
first_indent: None,
20+
compile_time: false,
1921
}
2022
}
2123

@@ -67,7 +69,7 @@ impl EmbedPython {
6769
self.loc.column += end.len();
6870
}
6971
TokenTree::Punct(x) => {
70-
if x.as_char() == '\'' && x.spacing() == Spacing::Joint {
72+
if !self.compile_time && x.as_char() == '\'' && x.spacing() == Spacing::Joint {
7173
let name = if let Some(TokenTree::Ident(name)) = tokens.next() {
7274
name
7375
} else {

macros/src/error.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use proc_macro2::{Span, TokenStream};
2+
use pyo3::type_object::PyTypeObject;
3+
use pyo3::{PyAny, AsPyRef, PyErr, PyResult, Python, ToPyObject};
4+
5+
/// Format a nice error message for a python compilation error.
6+
pub fn emit_compile_error_msg(py: Python, error: PyErr, tokens: TokenStream) {
7+
let value = error.to_object(py);
8+
9+
if value.is_none() {
10+
Span::call_site()
11+
.unwrap()
12+
.error(format!("python: {}", error.ptype.as_ref(py).name()))
13+
.emit();
14+
return;
15+
}
16+
17+
if error.matches(py, pyo3::exceptions::SyntaxError::type_object()) {
18+
let line: Option<usize> = value.getattr(py, "lineno").ok().and_then(|x| x.extract(py).ok());
19+
let msg: Option<String> = value.getattr(py, "msg").ok().and_then(|x| x.extract(py).ok());
20+
if let (Some(line), Some(msg)) = (line, msg) {
21+
if let Some(span) = span_for_line(tokens.clone(), line) {
22+
span.unwrap().error(format!("python: {}", msg)).emit();
23+
return;
24+
}
25+
}
26+
}
27+
28+
if let Some(tb) = &error.ptraceback {
29+
if let Ok((file, line)) = get_traceback_info(tb.as_ref(py)) {
30+
if file == Span::call_site().unwrap().source_file().path().to_string_lossy() {
31+
if let Ok(msg) = value.as_ref(py).str() {
32+
if let Some(span) = span_for_line(tokens, line) {
33+
span.unwrap().error(format!("python: {}", msg)).emit();
34+
return;
35+
}
36+
}
37+
}
38+
}
39+
}
40+
41+
Span::call_site()
42+
.unwrap()
43+
.error(format!("python: {}", value.as_ref(py).str().unwrap()))
44+
.emit();
45+
}
46+
47+
fn get_traceback_info(tb: &PyAny) -> PyResult<(String, usize)> {
48+
let frame = tb.getattr("tb_frame")?;
49+
let code = frame.getattr("f_code")?;
50+
let file: String = code.getattr("co_filename")?.extract()?;
51+
let line: usize = frame.getattr("f_lineno")?.extract()?;
52+
Ok((file, line))
53+
}
54+
55+
/// Get a span for a specific line of input from a TokenStream.
56+
fn span_for_line(input: TokenStream, line: usize) -> Option<Span> {
57+
let mut spans = input
58+
.into_iter()
59+
.map(|x| x.span().unwrap())
60+
.skip_while(|span| span.start().line < line)
61+
.take_while(|span| span.start().line == line);
62+
63+
let mut result = spans.next()?;
64+
for span in spans {
65+
result = match result.join(span) {
66+
None => return Some(Span::from(result)),
67+
Some(span) => span,
68+
}
69+
}
70+
71+
Some(Span::from(result))
72+
}

macros/src/lib.rs

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ extern crate proc_macro;
66
use self::embed_python::EmbedPython;
77
use proc_macro::TokenStream as TokenStream1;
88
use proc_macro2::{Literal, Span, TokenStream};
9-
use pyo3::{ffi, types::PyBytes, AsPyPointer, FromPyPointer, PyErr, PyObject, Python, ToPyObject};
9+
use pyo3::{ffi, types::PyBytes, AsPyPointer, FromPyPointer, PyObject, Python};
1010
use quote::quote;
1111
use std::ffi::CString;
1212

1313
mod embed_python;
14+
mod error;
15+
mod run;
1416

1517
fn python_impl(input: TokenStream) -> Result<TokenStream, ()> {
1618
let tokens = input.clone();
@@ -33,7 +35,7 @@ fn python_impl(input: TokenStream) -> Result<TokenStream, ()> {
3335
let py = gil.python();
3436

3537
let code = PyObject::from_owned_ptr_or_err(py, ffi::Py_CompileString(python.as_ptr(), filename.as_ptr(), ffi::Py_file_input))
36-
.map_err(|err| emit_compile_error_msg(py, err, tokens))?;
38+
.map_err(|err| error::emit_compile_error_msg(py, err, tokens))?;
3739

3840
Literal::byte_string(
3941
PyBytes::from_owned_ptr_or_err(py, ffi::PyMarshal_WriteObjectToString(code.as_ptr(), pyo3::marshal::VERSION))
@@ -59,6 +61,33 @@ fn python_impl(input: TokenStream) -> Result<TokenStream, ()> {
5961
})
6062
}
6163

64+
fn ct_python_impl(input: TokenStream) -> Result<TokenStream, ()> {
65+
let tokens = input.clone();
66+
67+
let filename = Span::call_site().unwrap().source_file().path().to_string_lossy().into_owned();
68+
69+
let mut x = EmbedPython::new();
70+
71+
x.compile_time = true;
72+
73+
x.add(input)?;
74+
75+
let EmbedPython { python, .. } = x;
76+
77+
let python = CString::new(python).unwrap();
78+
let filename = CString::new(filename).unwrap();
79+
80+
let gil = Python::acquire_gil();
81+
let py = gil.python();
82+
83+
let code = unsafe {
84+
PyObject::from_owned_ptr_or_err(py, ffi::Py_CompileString(python.as_ptr(), filename.as_ptr(), ffi::Py_file_input))
85+
.map_err(|err| error::emit_compile_error_msg(py, err, tokens.clone()))?
86+
};
87+
88+
run::run_ct_python(py, code, tokens)
89+
}
90+
6291
fn check_no_attribute(input: TokenStream) -> Result<(), ()> {
6392
let mut input = input.into_iter();
6493
if let Some(token) = input.next() {
@@ -88,53 +117,10 @@ pub fn python(input: TokenStream1) -> TokenStream1 {
88117
})
89118
}
90119

91-
/// Format a nice error message for a python compilation error.
92-
fn emit_compile_error_msg(py: Python, error: PyErr, tokens: TokenStream) {
93-
use pyo3::type_object::PyTypeObject;
94-
use pyo3::AsPyRef;
95-
96-
let value = error.to_object(py);
97-
98-
if value.is_none() {
99-
Span::call_site()
100-
.unwrap()
101-
.error(format!("python: {}", error.ptype.as_ref(py).name()))
102-
.emit();
103-
return;
104-
}
105-
106-
if error.matches(py, pyo3::exceptions::SyntaxError::type_object()) {
107-
let line: Option<usize> = value.getattr(py, "lineno").ok().and_then(|x| x.extract(py).ok());
108-
let msg: Option<String> = value.getattr(py, "msg").ok().and_then(|x| x.extract(py).ok());
109-
if let (Some(line), Some(msg)) = (line, msg) {
110-
if let Some(span) = span_for_line(tokens, line) {
111-
span.unwrap().error(format!("python: {}", msg)).emit();
112-
return;
113-
}
114-
}
115-
}
116-
117-
Span::call_site()
118-
.unwrap()
119-
.error(format!("python: {}", value.as_ref(py).str().unwrap()))
120-
.emit();
121-
}
122-
123-
/// Get a span for a specific line of input from a TokenStream.
124-
fn span_for_line(input: TokenStream, line: usize) -> Option<Span> {
125-
let mut spans = input
126-
.into_iter()
127-
.map(|x| x.span().unwrap())
128-
.skip_while(|span| span.start().line < line)
129-
.take_while(|span| span.start().line == line);
130-
131-
let mut result = spans.next()?;
132-
for span in spans {
133-
result = match result.join(span) {
134-
None => return Some(Span::from(result)),
135-
Some(span) => span,
136-
}
137-
}
138-
139-
Some(Span::from(result))
120+
#[proc_macro]
121+
pub fn ct_python(input: TokenStream1) -> TokenStream1 {
122+
TokenStream1::from(match ct_python_impl(TokenStream::from(input)) {
123+
Ok(tokens) => tokens,
124+
Err(()) => quote!(unimplemented!()).into()
125+
})
140126
}

macros/src/run.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use crate::error::emit_compile_error_msg;
2+
use proc_macro2::{Span, TokenStream};
3+
use pyo3::{ffi, AsPyPointer, PyObject, PyResult, Python};
4+
use std::str::FromStr;
5+
6+
fn run_and_capture(py: Python, code: PyObject) -> PyResult<String> {
7+
let globals = py.import("__main__")?.dict().copy()?;
8+
9+
let sys = py.import("sys")?;
10+
let io = py.import("io")?;
11+
12+
let stdout = io.call0("StringIO")?;
13+
let original_stdout = sys.dict().get_item("stdout");
14+
sys.dict().set_item("stdout", stdout)?;
15+
16+
let result =
17+
unsafe { PyObject::from_owned_ptr_or_err(py, ffi::PyEval_EvalCode(code.as_ptr(), globals.as_ptr(), std::ptr::null_mut())) };
18+
19+
sys.dict().set_item("stdout", original_stdout)?;
20+
21+
result?;
22+
23+
stdout.call_method0("getvalue")?.extract()
24+
}
25+
26+
pub fn run_ct_python(py: Python, code: PyObject, tokens: TokenStream) -> Result<TokenStream, ()> {
27+
let output = run_and_capture(py, code).map_err(|err| emit_compile_error_msg(py, err, tokens))?;
28+
29+
Ok(TokenStream::from_str(&output).map_err(|e| {
30+
Span::call_site()
31+
.unwrap()
32+
.error(format!("Unable to parse output of ct_python!{{}} script: {:?}", e))
33+
.emit()
34+
})?)
35+
}

src/flush.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// extern "C" {
2+
// fn fflush(_: *mut std::ffi::c_void);
3+
// static mut stdout: *mut std::ffi::c_void;
4+
// }
5+
6+
pub fn flush() {}

0 commit comments

Comments
 (0)