Skip to content
This repository was archived by the owner on Mar 25, 2025. It is now read-only.

Commit 5952a87

Browse files
Support for Breakpoint block (nrn_cur) for code generation (#645)
* Support for Breakpoint block (nrn_cur) for code generation * similar to DERIVATIVE (nrn_state), handle BREAKPOINT (nrn_cur) blocks with AST level transformation * Move common code from CodegenCVisitor to CodegenInfo * Add tests fixes #644 Co-authored-by: George Mitenkov <georgemitenk0v@gmail.com>
1 parent 64e8cee commit 5952a87

18 files changed

+987
-129
lines changed

src/codegen/codegen_acc_visitor.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ void CodegenAccVisitor::print_net_init_acc_serial_annotation_block_end() {
200200
}
201201

202202
void CodegenAccVisitor::print_nrn_cur_matrix_shadow_update() {
203-
auto rhs_op = operator_for_rhs();
204-
auto d_op = operator_for_d();
203+
const auto& rhs_op = info.operator_for_rhs();
204+
const auto& d_op = info.operator_for_d();
205205
print_atomic_reduction_pragma();
206206
printer->add_line("vec_rhs[node_id] {} rhs;"_format(rhs_op));
207207
print_atomic_reduction_pragma();
@@ -213,8 +213,8 @@ void CodegenAccVisitor::print_fast_imem_calculation() {
213213
return;
214214
}
215215

216-
auto rhs_op = operator_for_rhs();
217-
auto d_op = operator_for_d();
216+
const auto& rhs_op = info.operator_for_rhs();
217+
const auto& d_op = info.operator_for_d();
218218
printer->start_block("if (nt->nrn_fast_imem)");
219219
print_atomic_reduction_pragma();
220220
printer->add_line("nt->nrn_fast_imem->nrn_sav_rhs[node_id] {} rhs;"_format(rhs_op));

src/codegen/codegen_c_visitor.cpp

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -325,37 +325,6 @@ void CodegenCVisitor::visit_update_dt(const ast::UpdateDt& node) {
325325
/* Common helper routines */
326326
/****************************************************************************************/
327327

328-
329-
/**
330-
* \details Certain statements like unit, comment, solve can/need to be skipped
331-
* during code generation. Note that solve block is wrapped in expression
332-
* statement and hence we have to check inner expression. It's also true
333-
* for the initial block defined inside net receive block.
334-
*/
335-
bool CodegenCVisitor::statement_to_skip(const Statement& node) const {
336-
// clang-format off
337-
if (node.is_unit_state()
338-
|| node.is_line_comment()
339-
|| node.is_block_comment()
340-
|| node.is_solve_block()
341-
|| node.is_conductance_hint()
342-
|| node.is_table_statement()) {
343-
return true;
344-
}
345-
// clang-format on
346-
if (node.is_expression_statement()) {
347-
auto expression = dynamic_cast<const ExpressionStatement*>(&node)->get_expression();
348-
if (expression->is_solve_block()) {
349-
return true;
350-
}
351-
if (expression->is_initial_block()) {
352-
return true;
353-
}
354-
}
355-
return false;
356-
}
357-
358-
359328
/**
360329
* \details When floating point data type is not default (i.e. double) then we
361330
* have to copy old array to new type (for range variables).
@@ -974,8 +943,8 @@ void CodegenCVisitor::print_nrn_cur_matrix_shadow_update() {
974943
printer->add_line("shadow_rhs[id] = rhs;");
975944
printer->add_line("shadow_d[id] = g;");
976945
} else {
977-
auto rhs_op = operator_for_rhs();
978-
auto d_op = operator_for_d();
946+
const auto& rhs_op = info.operator_for_rhs();
947+
const auto& d_op = info.operator_for_d();
979948
print_atomic_reduction_pragma();
980949
printer->add_line("vec_rhs[node_id] {} rhs;"_format(rhs_op));
981950
print_atomic_reduction_pragma();
@@ -986,8 +955,8 @@ void CodegenCVisitor::print_nrn_cur_matrix_shadow_update() {
986955

987956

988957
void CodegenCVisitor::print_nrn_cur_matrix_shadow_reduction() {
989-
auto rhs_op = operator_for_rhs();
990-
auto d_op = operator_for_d();
958+
const auto& rhs_op = info.operator_for_rhs();
959+
const auto& d_op = info.operator_for_d();
991960
if (channel_task_dependency_enabled()) {
992961
auto rhs = get_variable_name("ml_rhs");
993962
auto d = get_variable_name("ml_d");
@@ -1167,7 +1136,7 @@ void CodegenCVisitor::print_statement_block(const ast::StatementBlock& node,
11671136

11681137
auto statements = node.get_statements();
11691138
for (const auto& statement: statements) {
1170-
if (statement_to_skip(*statement)) {
1139+
if (info.statement_to_skip(*statement)) {
11711140
continue;
11721141
}
11731142
/// not necessary to add indent for verbatim block (pretty-printing)
@@ -4337,8 +4306,8 @@ void CodegenCVisitor::print_fast_imem_calculation() {
43374306
return;
43384307
}
43394308
std::string rhs, d;
4340-
auto rhs_op = operator_for_rhs();
4341-
auto d_op = operator_for_d();
4309+
const auto& rhs_op = info.operator_for_rhs();
4310+
const auto& d_op = info.operator_for_d();
43424311
if (channel_task_dependency_enabled()) {
43434312
rhs = get_variable_name("ml_rhs");
43444313
d = get_variable_name("ml_d");

src/codegen/codegen_c_visitor.hpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -218,23 +218,6 @@ class CodegenCVisitor: public visitor::ConstAstVisitor {
218218
return "\"" + text + "\"";
219219
}
220220

221-
222-
/**
223-
* Operator for rhs vector update (matrix update)
224-
*/
225-
std::string operator_for_rhs() const noexcept {
226-
return info.electrode_current ? "+=" : "-=";
227-
}
228-
229-
230-
/**
231-
* Operator for diagonal vector update (matrix update)
232-
*/
233-
std::string operator_for_d() const noexcept {
234-
return info.electrode_current ? "-=" : "+=";
235-
}
236-
237-
238221
/**
239222
* Data type for the local variables
240223
*/

src/codegen/codegen_cuda_visitor.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ void CodegenCudaVisitor::print_device_method_annotation() {
9696

9797

9898
void CodegenCudaVisitor::print_nrn_cur_matrix_shadow_update() {
99-
auto rhs_op = operator_for_rhs();
100-
auto d_op = operator_for_d();
99+
auto rhs_op = info.operator_for_rhs();
100+
auto d_op = info.operator_for_d();
101101
stringutils::remove_character(rhs_op, '=');
102102
stringutils::remove_character(d_op, '=');
103103
print_atomic_op("vec_rhs[node_id]", rhs_op, "rhs");
@@ -109,8 +109,8 @@ void CodegenCudaVisitor::print_fast_imem_calculation() {
109109
return;
110110
}
111111

112-
auto rhs_op = operator_for_rhs();
113-
auto d_op = operator_for_d();
112+
auto rhs_op = info.operator_for_rhs();
113+
auto d_op = info.operator_for_d();
114114
stringutils::remove_character(rhs_op, '=');
115115
stringutils::remove_character(d_op, '=');
116116
printer->start_block("if (nt->nrn_fast_imem)");

src/codegen/codegen_driver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ bool CodegenDriver::prepare_mod(std::shared_ptr<ast::Program> node, const std::s
179179
/// that old symbols (e.g. prime variables) are not lost
180180
update_symtab = true;
181181

182-
if (cfg.nmodl_inline) {
182+
if (cfg.nmodl_inline || cfg.llvm_ir) {
183183
logger->info("Running nmodl inline visitor");
184184
InlineVisitor().visit_program(*node);
185185
ast_to_nmodl(*node, filepath("inline", "mod"));

src/codegen/codegen_info.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,5 +404,34 @@ void CodegenInfo::get_float_variables() {
404404
}
405405
}
406406

407+
/**
408+
* \details Certain statements like unit, comment, solve can/need to be skipped
409+
* during code generation. Note that solve block is wrapped in expression
410+
* statement and hence we have to check inner expression. It's also true
411+
* for the initial block defined inside net receive block.
412+
*/
413+
bool CodegenInfo::statement_to_skip(const ast::Statement& node) const {
414+
// clang-format off
415+
if (node.is_unit_state()
416+
|| node.is_line_comment()
417+
|| node.is_block_comment()
418+
|| node.is_solve_block()
419+
|| node.is_conductance_hint()
420+
|| node.is_table_statement()) {
421+
return true;
422+
}
423+
// clang-format on
424+
if (node.is_expression_statement()) {
425+
auto expression = dynamic_cast<const ast::ExpressionStatement*>(&node)->get_expression();
426+
if (expression->is_solve_block()) {
427+
return true;
428+
}
429+
if (expression->is_initial_block()) {
430+
return true;
431+
}
432+
}
433+
return false;
434+
}
435+
407436
} // namespace codegen
408437
} // namespace nmodl

src/codegen/codegen_info.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,29 @@ struct CodegenInfo {
590590
}
591591

592592

593+
/**
594+
* Operator for rhs vector update (matrix update)
595+
*
596+
* Note that we only rely on following two syntax for
597+
* increment and decrement. Code generation backends
598+
* are relying on this convention.
599+
*/
600+
std::string operator_for_rhs() const noexcept {
601+
return electrode_current ? "+=" : "-=";
602+
}
603+
604+
605+
/**
606+
* Operator for diagonal vector update (matrix update)
607+
*
608+
* Note that we only rely on following two syntax for
609+
* increment and decrement. Code generation backends
610+
* are relying on this convention.
611+
*/
612+
std::string operator_for_d() const noexcept {
613+
return electrode_current ? "-=" : "+=";
614+
}
615+
593616
/**
594617
* Check if net_receive function is required
595618
*/
@@ -657,6 +680,13 @@ struct CodegenInfo {
657680
* \return A \c vector of \c float variables
658681
*/
659682
void get_float_variables();
683+
684+
/**
685+
* Check if statement should be skipped for code generation
686+
* @param node Statement to be checked for code generation
687+
* @return True if statement should be skipped otherwise false
688+
*/
689+
bool statement_to_skip(const ast::Statement& node) const;
660690
};
661691

662692
/** @} */ // end of codegen_backends

src/codegen/codegen_ispc_visitor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ void CodegenIspcVisitor::print_atomic_op(const std::string& lhs,
248248

249249

250250
void CodegenIspcVisitor::print_nrn_cur_matrix_shadow_reduction() {
251-
auto rhs_op = operator_for_rhs();
252-
auto d_op = operator_for_d();
251+
const auto& rhs_op = info.operator_for_rhs();
252+
const auto& d_op = info.operator_for_d();
253253
if (info.point_process) {
254254
printer->add_line("uniform int node_id = node_index[id];");
255255
printer->add_line("vec_rhs[node_id] {} shadow_rhs[id];"_format(rhs_op));

src/codegen/codegen_naming.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ static constexpr char NTHREAD_RHS_SHADOW[] = "_shadow_rhs";
9292
/// shadow d variable in neuron thread structure
9393
static constexpr char NTHREAD_D_SHADOW[] = "_shadow_d";
9494

95+
/// rhs variable in neuron thread structure
96+
static constexpr char NTHREAD_RHS[] = "vec_rhs";
97+
98+
/// d variable in neuron thread structure
99+
static constexpr char NTHREAD_D[] = "vec_d";
100+
95101
/// t variable in neuron thread structure
96102
static constexpr char NTHREAD_T_VARIABLE[] = "t";
97103

0 commit comments

Comments
 (0)