|
4 | 4 | #![forbid(unsafe_code)] |
5 | 5 | #![forbid(where_clauses_object_safety)] |
6 | 6 |
|
| 7 | +use std::collections::HashMap; |
| 8 | + |
7 | 9 | use huff_utils::{ |
8 | 10 | ast::*, |
9 | 11 | error::*, |
@@ -136,6 +138,8 @@ impl Parser { |
136 | 138 | } |
137 | 139 | } |
138 | 140 |
|
| 141 | + validate_macros(&contract)?; |
| 142 | + |
139 | 143 | Ok(contract) |
140 | 144 | } |
141 | 145 |
|
@@ -521,63 +525,6 @@ impl Parser { |
521 | 525 |
|
522 | 526 | let macro_statements: Vec<Statement> = self.parse_body()?; |
523 | 527 |
|
524 | | - if outlined { |
525 | | - let (body_statements_take, body_statements_return) = |
526 | | - macro_statements.iter().fold((0i16, 0i16), |acc, st| { |
527 | | - let (statement_takes, statement_returns) = match st.ty { |
528 | | - StatementType::Literal(_) | |
529 | | - StatementType::Constant(_) | |
530 | | - StatementType::BuiltinFunctionCall(_) | |
531 | | - StatementType::ArgCall(_) | |
532 | | - StatementType::LabelCall(_) => (0i8, 1i8), |
533 | | - StatementType::Opcode(opcode) => { |
534 | | - if opcode.is_value_push() { |
535 | | - (0i8, 0i8) |
536 | | - } else { |
537 | | - let stack_changes = opcode.stack_changes(); |
538 | | - (stack_changes.0 as i8, stack_changes.1 as i8) |
539 | | - } |
540 | | - } |
541 | | - StatementType::Label(_) => (0i8, 0i8), |
542 | | - StatementType::MacroInvocation(_) => { |
543 | | - todo!() |
544 | | - } |
545 | | - StatementType::Code(_) => { |
546 | | - todo!("should throw error") |
547 | | - } |
548 | | - }; |
549 | | - |
550 | | - // acc.1 is always non negative |
551 | | - // acc.0 is always non positive |
552 | | - let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { |
553 | | - (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) |
554 | | - } else { |
555 | | - (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) |
556 | | - }; |
557 | | - (stack_takes, stack_returns) |
558 | | - }); |
559 | | - if body_statements_take.abs() != macro_takes as i16 { |
560 | | - return Err(ParserError { |
561 | | - kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), |
562 | | - hint: Some(format!( |
563 | | - "Fn {macro_name} specified to take {macro_takes} elements from the stack, but it takes {}", |
564 | | - body_statements_take.abs() |
565 | | - )), |
566 | | - spans: AstSpan(self.spans.clone()), |
567 | | - }); |
568 | | - } |
569 | | - if body_statements_return != macro_returns as i16 { |
570 | | - return Err(ParserError { |
571 | | - kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), |
572 | | - hint: Some(format!( |
573 | | - "Fn {macro_name} specified to return {macro_returns} elements to the stack, but it returns {}", |
574 | | - body_statements_return |
575 | | - )), |
576 | | - spans: AstSpan(self.spans.clone()), |
577 | | - }); |
578 | | - } |
579 | | - } |
580 | | - |
581 | 528 | Ok(MacroDefinition::new( |
582 | 529 | macro_name, |
583 | 530 | decorator, |
@@ -1336,3 +1283,92 @@ impl Parser { |
1336 | 1283 | } |
1337 | 1284 | } |
1338 | 1285 | } |
| 1286 | + |
| 1287 | +/// Function used to evaluate macro statements. Returns number of elements taken from the stack and |
| 1288 | +/// returned to the stack |
| 1289 | +pub fn evaluate_macro( |
| 1290 | + _macro_name: &str, |
| 1291 | + _macros: &[MacroDefinition], |
| 1292 | + _evaluated_macros: &mut HashMap<String, (i16, i16)>, |
| 1293 | +) -> Result<(i16, i16), ParserError> { |
| 1294 | + if _evaluated_macros.contains_key(_macro_name) { |
| 1295 | + return Ok(*_evaluated_macros.get(_macro_name).unwrap()) |
| 1296 | + } |
| 1297 | + |
| 1298 | + let _macro = _macros.iter().find(|m| m.name.as_str() == _macro_name).unwrap(); |
| 1299 | + let (body_statements_take, body_statements_return) = |
| 1300 | + _macro.statements.iter().fold((0i16, 0i16), |acc, st| { |
| 1301 | + let (statement_takes, statement_returns) = match &st.ty { |
| 1302 | + StatementType::Literal(_) | |
| 1303 | + StatementType::Constant(_) | |
| 1304 | + StatementType::BuiltinFunctionCall(_) | |
| 1305 | + StatementType::ArgCall(_) => (0i8, 1i8), |
| 1306 | + StatementType::LabelCall(_) => (0i8, 1i8), |
| 1307 | + StatementType::Opcode(opcode) => { |
| 1308 | + if opcode.is_value_push() { |
| 1309 | + (0i8, 0i8) |
| 1310 | + } else { |
| 1311 | + let stack_changes = opcode.stack_changes(); |
| 1312 | + (stack_changes.0 as i8, stack_changes.1 as i8) |
| 1313 | + } |
| 1314 | + } |
| 1315 | + StatementType::Label(_) => (0i8, 0i8), |
| 1316 | + StatementType::MacroInvocation(MacroInvocation { |
| 1317 | + macro_name, |
| 1318 | + args: _, |
| 1319 | + span: _, |
| 1320 | + }) => { |
| 1321 | + let (takes, returns) = |
| 1322 | + evaluate_macro(macro_name, _macros, _evaluated_macros).unwrap(); |
| 1323 | + (takes.abs() as i8, returns as i8) |
| 1324 | + } |
| 1325 | + StatementType::Code(_) => { |
| 1326 | + todo!("should throw error") |
| 1327 | + } |
| 1328 | + }; |
| 1329 | + |
| 1330 | + // acc.1 is always non negative |
| 1331 | + // acc.0 is always non positive |
| 1332 | + let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { |
| 1333 | + (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) |
| 1334 | + } else { |
| 1335 | + (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) |
| 1336 | + }; |
| 1337 | + (stack_takes, stack_returns) |
| 1338 | + }); |
| 1339 | + |
| 1340 | + _evaluated_macros.insert(_macro.name.clone(), (body_statements_take, body_statements_return)); |
| 1341 | + Ok((body_statements_take, body_statements_return)) |
| 1342 | +} |
| 1343 | + |
| 1344 | +/// Function used to validate takes and returns of outlined macros in the contract |
| 1345 | +pub fn validate_macros(contract: &Contract) -> Result<(), ParserError> { |
| 1346 | + let mut evaluated_macros = HashMap::new(); |
| 1347 | + for _macro in contract.macros.iter().filter(|m| m.outlined) { |
| 1348 | + let (body_statements_take, body_statements_return) = |
| 1349 | + evaluate_macro(&_macro.name, &contract.macros, &mut evaluated_macros)?; |
| 1350 | + if body_statements_take.abs() != _macro.takes as i16 { |
| 1351 | + return Err(ParserError { |
| 1352 | + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), |
| 1353 | + hint: Some(format!( |
| 1354 | + "Fn {} specified to take {} elements from the stack, but it takes {}", |
| 1355 | + _macro.name, |
| 1356 | + _macro.takes, |
| 1357 | + body_statements_take.abs() |
| 1358 | + )), |
| 1359 | + spans: _macro.span.clone(), |
| 1360 | + }) |
| 1361 | + } |
| 1362 | + if body_statements_return != _macro.returns as i16 { |
| 1363 | + return Err(ParserError { |
| 1364 | + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), |
| 1365 | + hint: Some(format!( |
| 1366 | + "Fn {} specified to return {} elements to the stack, but it returns {}", |
| 1367 | + _macro.name, _macro.returns, body_statements_return |
| 1368 | + )), |
| 1369 | + spans: _macro.span.clone(), |
| 1370 | + }) |
| 1371 | + } |
| 1372 | + } |
| 1373 | + Ok(()) |
| 1374 | +} |
0 commit comments